[ir][spirv-writer] VarForDynamicIndex transform
Look for access instructions that dynamically index into array and
matrix values, and replace them with accesses into a locally declared
copy of the value.
Use the transform in the SPIR-V writer.
Bug: tint:1906
Change-Id: Ie3872d8cf005948f64e73c455f5a38bfdab692c4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/136360
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 1ca8658..bd06461 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -523,6 +523,8 @@
"ir/transform/add_empty_entry_point.h",
"ir/transform/block_decorated_structs.cc",
"ir/transform/block_decorated_structs.h",
+ "ir/transform/var_for_dynamic_index.cc",
+ "ir/transform/var_for_dynamic_index.h",
]
deps = [
":libtint_builtins_src",
@@ -1818,6 +1820,7 @@
"ir/transform/add_empty_entry_point_test.cc",
"ir/transform/block_decorated_structs_test.cc",
"ir/transform/test_helper.h",
+ "ir/transform/var_for_dynamic_index_test.cc",
]
deps = [
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index af06ca5..764be39 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -804,6 +804,8 @@
ir/transform/block_decorated_structs.h
ir/transform/transform.cc
ir/transform/transform.h
+ ir/transform/var_for_dynamic_index.cc
+ ir/transform/var_for_dynamic_index.h
)
endif()
@@ -1537,6 +1539,7 @@
ir/swizzle_test.cc
ir/transform/add_empty_entry_point_test.cc
ir/transform/block_decorated_structs_test.cc
+ ir/transform/var_for_dynamic_index_test.cc
ir/unary_test.cc
ir/user_call_test.cc
ir/validate_test.cc
diff --git a/src/tint/ir/access.h b/src/tint/ir/access.h
index c798b13..55f7306 100644
--- a/src/tint/ir/access.h
+++ b/src/tint/ir/access.h
@@ -41,6 +41,9 @@
return operands_.Slice().Offset(1).Reinterpret<Value const* const>();
}
+ /// @returns the accessor indices
+ utils::Slice<Value*> Indices() { return operands_.Slice().Offset(1); }
+
private:
const type::Type* result_type_ = nullptr;
};
diff --git a/src/tint/ir/transform/var_for_dynamic_index.cc b/src/tint/ir/transform/var_for_dynamic_index.cc
new file mode 100644
index 0000000..fa063e7
--- /dev/null
+++ b/src/tint/ir/transform/var_for_dynamic_index.cc
@@ -0,0 +1,174 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/var_for_dynamic_index.h"
+
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+#include "src/tint/switch.h"
+#include "src/tint/type/array.h"
+#include "src/tint/type/matrix.h"
+#include "src/tint/type/pointer.h"
+#include "src/tint/type/struct.h"
+#include "src/tint/type/vector.h"
+#include "src/tint/utils/hashmap.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::VarForDynamicIndex);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ir::transform {
+
+namespace {
+// An access that needs replacing.
+struct AccessToReplace {
+ // The access instruction.
+ Access* access = nullptr;
+ // The index of the first dynamic index.
+ uint32_t first_dynamic_index = 0;
+ // The object type that corresponds to the source of the first dynamic index.
+ const type::Type* dynamic_index_source_type = nullptr;
+};
+
+// A partial access chain that uses constant indices to get to an object that will be
+// dynamically indexed.
+struct PartialAccess {
+ // The base object.
+ Value* base = nullptr;
+ // The list of constant indices to get from the base to the source object.
+ utils::Vector<Value*, 4> indices;
+
+ // A specialization of utils::Hasher for PartialAccess.
+ struct Hasher {
+ inline std::size_t operator()(const PartialAccess& src) const {
+ return utils::Hash(src.base, src.indices);
+ }
+ };
+
+ // An equality helper for PartialAccess.
+ bool operator==(const PartialAccess& other) const {
+ return base == other.base && indices == other.indices;
+ }
+};
+} // namespace
+
+VarForDynamicIndex::VarForDynamicIndex() = default;
+
+VarForDynamicIndex::~VarForDynamicIndex() = default;
+
+static std::optional<AccessToReplace> ShouldReplace(Access* access) {
+ AccessToReplace to_replace{access, 0, access->Object()->Type()};
+
+ // Find the first dynamic index, if any.
+ bool has_dynamic_index = false;
+ for (auto* idx : access->Indices()) {
+ if (to_replace.dynamic_index_source_type->Is<type::Vector>()) {
+ // Stop if we hit a vector, as they can support dynamic accesses.
+ break;
+ }
+
+ // Check if the index is dynamic.
+ auto* const_idx = idx->As<Constant>();
+ if (!const_idx) {
+ has_dynamic_index = true;
+ break;
+ }
+ to_replace.first_dynamic_index++;
+
+ // Update the current object type.
+ to_replace.dynamic_index_source_type = tint::Switch(
+ to_replace.dynamic_index_source_type, //
+ [&](const type::Array* arr) { return arr->ElemType(); },
+ [&](const type::Matrix* mat) { return mat->ColumnType(); },
+ [&](const type::Struct* str) {
+ return str->Members()[const_idx->Value()->ValueAs<u32>()]->Type();
+ },
+ [&](const type::Vector* vec) { return vec->type(); }, //
+ [&](Default) { return nullptr; });
+ }
+ if (!has_dynamic_index) {
+ // No need to modify accesses that only use constant indices.
+ return {};
+ }
+
+ return to_replace;
+}
+
+void VarForDynamicIndex::Run(ir::Module* ir, const DataMap&, DataMap&) const {
+ ir::Builder builder(*ir);
+
+ // Find the access instructions that need replacing.
+ utils::Vector<AccessToReplace, 4> worklist;
+ for (auto* inst : ir->values.Objects()) {
+ auto* access = inst->As<Access>();
+ if (access && !access->Type()->Is<type::Pointer>()) {
+ if (auto to_replace = ShouldReplace(access)) {
+ worklist.Push(to_replace.value());
+ }
+ }
+ }
+
+ // Replace each access instruction that we recorded.
+ utils::Hashmap<Value*, Value*, 4> object_to_local;
+ utils::Hashmap<PartialAccess, Value*, 4, PartialAccess::Hasher> source_object_to_value;
+ for (const auto& to_replace : worklist) {
+ auto* access = to_replace.access;
+ Value* source_object = access->Object();
+
+ // If the access starts with at least one constant index, extract the source of the first
+ // dynamic access to avoid copying the whole object.
+ if (to_replace.first_dynamic_index > 0) {
+ PartialAccess partial_access = {
+ access->Object(), access->Indices().Truncate(to_replace.first_dynamic_index)};
+ source_object = source_object_to_value.GetOrCreate(partial_access, [&]() {
+ auto* intermediate_source = builder.Access(to_replace.dynamic_index_source_type,
+ source_object, partial_access.indices);
+ intermediate_source->InsertBefore(access);
+ return intermediate_source;
+ });
+ }
+
+ // Declare a local variable and copy the source object to it.
+ auto* local = object_to_local.GetOrCreate(source_object, [&]() {
+ auto* decl = builder.Declare(ir->Types().pointer(to_replace.dynamic_index_source_type,
+ builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite));
+ decl->SetInitializer(source_object);
+ decl->InsertBefore(access);
+ return decl;
+ });
+
+ // Create a new access instruction using the local variable as the source.
+ utils::Vector<Value*, 4> indices{access->Indices().Offset(to_replace.first_dynamic_index)};
+ auto* new_access =
+ builder.Access(ir->Types().pointer(access->Type(), builtin::AddressSpace::kFunction,
+ builtin::Access::kReadWrite),
+ local, indices);
+ access->ReplaceWith(new_access);
+
+ // Load from the access to get the final result value.
+ auto* load = builder.Load(new_access);
+ load->InsertAfter(new_access);
+
+ // Replace all uses of the old access instruction with the loaded result.
+ while (!access->Usages().IsEmpty()) {
+ auto& use = *access->Usages().begin();
+ use.instruction->SetOperand(use.operand_index, load);
+ }
+ }
+}
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/var_for_dynamic_index.h b/src/tint/ir/transform/var_for_dynamic_index.h
new file mode 100644
index 0000000..1f86b7d
--- /dev/null
+++ b/src/tint/ir/transform/var_for_dynamic_index.h
@@ -0,0 +1,39 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
+#define SRC_TINT_IR_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
+
+#include "src/tint/ir/transform/transform.h"
+
+namespace tint::ir::transform {
+
+/// VarForDynamicIndex is a transform that copies array and matrix values that are dynamically
+/// indexed to a temporary local `var` before performing the index. This transform is used by the
+/// SPIR-V writer as there is no SPIR-V instruction that can dynamically index a non-pointer
+/// composite.
+class VarForDynamicIndex final : public utils::Castable<VarForDynamicIndex, Transform> {
+ public:
+ /// Constructor
+ VarForDynamicIndex();
+ /// Destructor
+ ~VarForDynamicIndex() override;
+
+ /// @copydoc Transform::Run
+ void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
diff --git a/src/tint/ir/transform/var_for_dynamic_index_test.cc b/src/tint/ir/transform/var_for_dynamic_index_test.cc
new file mode 100644
index 0000000..0f6af06
--- /dev/null
+++ b/src/tint/ir/transform/var_for_dynamic_index_test.cc
@@ -0,0 +1,428 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/var_for_dynamic_index.h"
+
+#include <utility>
+
+#include "src/tint/ir/transform/test_helper.h"
+#include "src/tint/type/array.h"
+#include "src/tint/type/matrix.h"
+#include "src/tint/type/struct.h"
+
+namespace tint::ir::transform {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+class IR_VarForDynamicIndexTest : public TransformTest {
+ protected:
+ const type::Type* ptr(const type::Type* elem) {
+ return ty.pointer(elem, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ }
+};
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_ConstantIndex_ArrayValue) {
+ auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u));
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr});
+
+ auto* access = b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_i)});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<i32, 4>):i32 -> %b1 {
+ %b1 = block {
+ %3:i32 = access %2, 1i
+ ret %3
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_ConstantIndex_MatrixValue) {
+ auto* mat = b.FunctionParam(ty.mat2x2(ty.f32()));
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{mat});
+
+ auto* access = b.Access(ty.f32(), mat, utils::Vector{b.Constant(1_i), b.Constant(0_i)});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:mat2x2<f32>):f32 -> %b1 {
+ %b1 = block {
+ %3:f32 = access %2, 1i, 0i
+ ret %3
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_ArrayPointer) {
+ auto* arr = b.FunctionParam(ptr(ty.array(ty.i32(), 4u)));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ptr(ty.i32()), arr, utils::Vector{idx});
+ auto* load = b.Load(access);
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(load);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{load}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:ptr<function, array<i32, 4>, read_write>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, i32, read_write> = access %2, %3
+ %5:i32 = load %4
+ ret %5
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_MatrixPointer) {
+ auto* mat = b.FunctionParam(ptr(ty.mat2x2(ty.f32())));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{mat, idx});
+
+ auto* access = b.Access(ptr(ty.f32()), mat, utils::Vector{idx, idx});
+ auto* load = b.Load(access);
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(load);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{load}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:ptr<function, mat2x2<f32>, read_write>, %3:i32):f32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, f32, read_write> = access %2, %3, %3
+ %5:f32 = load %4
+ ret %5
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_VectorValue) {
+ auto* vec = b.FunctionParam(ty.vec4(ty.f32()));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{vec, idx});
+
+ auto* access = b.Access(ty.f32(), vec, utils::Vector{idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:vec4<f32>, %3:i32):f32 -> %b1 {
+ %b1 = block {
+ %4:f32 = access %2, %3
+ ret %4
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, DynamicIndex_ArrayValue) {
+ auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ty.i32(), arr, utils::Vector{idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<i32, 4>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, array<i32, 4>, read_write> = var, %2
+ %5:ptr<function, i32, read_write> = access %4, %3
+ %6:i32 = load %5
+ ret %6
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, DynamicIndex_MatrixValue) {
+ auto* arr = b.FunctionParam(ty.mat2x2(ty.f32()));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ty.f32(), arr, utils::Vector{idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:mat2x2<f32>, %3:i32):f32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, mat2x2<f32>, read_write> = var, %2
+ %5:ptr<function, f32, read_write> = access %4, %3
+ %6:f32 = load %5
+ ret %6
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, AccessChain) {
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ty.i32(), arr, utils::Vector{idx, b.Constant(1_u), idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:ptr<function, array<array<array<i32, 4>, 4>, 4>, read_write> = var, %2
+ %5:ptr<function, i32, read_write> = access %4, %3, 1u, %3
+ %6:i32 = load %5
+ ret %6
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices) {
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access = b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:array<i32, 4> = access %2, 1u, 2u
+ %5:ptr<function, array<i32, 4>, read_write> = var, %4
+ %6:ptr<function, i32, read_write> = access %5, %3
+ %7:i32 = load %6
+ ret %7
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices_Interleaved) {
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u), 4u));
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx});
+
+ auto* access =
+ b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), idx, b.Constant(2_u), idx});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<array<array<array<i32, 4>, 4>, 4>, 4>, %3:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:array<array<array<i32, 4>, 4>, 4> = access %2, 1u
+ %5:ptr<function, array<array<array<i32, 4>, 4>, 4>, read_write> = var, %4
+ %6:ptr<function, i32, read_write> = access %5, %3, 2u, %3
+ %7:i32 = load %6
+ ret %7
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices_Struct) {
+ auto* str_ty = ty.Get<type::Struct>(
+ mod.symbols.Register("MyStruct"),
+ utils::Vector{
+ ty.Get<type::StructMember>(mod.symbols.Register("arr1"), ty.array(ty.f32(), 1024u), 0u,
+ 0u, 4u, 4096u, type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("mat"), ty.mat4x4(ty.f32()), 1u, 4096u,
+ 16u, 64u, type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("arr2"), ty.array(ty.f32(), 1024u), 2u,
+ 4160u, 4u, 4096u, type::StructMemberAttributes{}),
+ },
+ 16u, 32u, 32u);
+ auto* str_val = b.FunctionParam(str_ty);
+ auto* idx = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.f32());
+ func->SetParams(utils::Vector{str_val, idx});
+
+ auto* access =
+ b.Access(ty.f32(), str_val, utils::Vector{b.Constant(1_u), idx, b.Constant(0_u)});
+ func->StartTarget()->Append(access);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+MyStruct = struct @align(16) {
+ arr1:array<f32, 1024> @offset(0)
+ mat:mat4x4<f32> @offset(4096)
+ arr2:array<f32, 1024> @offset(4160)
+}
+
+%foo = func(%2:MyStruct, %3:i32):f32 -> %b1 {
+ %b1 = block {
+ %4:mat4x4<f32> = access %2, 1u
+ %5:ptr<function, mat4x4<f32>, read_write> = var, %4
+ %6:ptr<function, f32, read_write> = access %5, %3, 0u
+ %7:f32 = load %6
+ ret %7
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, MultipleAccessesFromSameSource) {
+ auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u));
+ auto* idx_a = b.FunctionParam(ty.i32());
+ auto* idx_b = b.FunctionParam(ty.i32());
+ auto* idx_c = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx_a, idx_b, idx_c});
+
+ auto* access_a = b.Access(ty.i32(), arr, utils::Vector{idx_a});
+ auto* access_b = b.Access(ty.i32(), arr, utils::Vector{idx_b});
+ auto* access_c = b.Access(ty.i32(), arr, utils::Vector{idx_c});
+ func->StartTarget()->Append(access_a);
+ func->StartTarget()->Append(access_b);
+ func->StartTarget()->Append(access_c);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access_c}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<i32, 4>, %3:i32, %4:i32, %5:i32):i32 -> %b1 {
+ %b1 = block {
+ %6:ptr<function, array<i32, 4>, read_write> = var, %2
+ %7:ptr<function, i32, read_write> = access %6, %3
+ %8:i32 = load %7
+ %9:ptr<function, i32, read_write> = access %6, %4
+ %10:i32 = load %9
+ %11:ptr<function, i32, read_write> = access %6, %5
+ %12:i32 = load %11
+ ret %12
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VarForDynamicIndexTest, MultipleAccessesFromSameSource_SkipConstantIndices) {
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u));
+ auto* idx_a = b.FunctionParam(ty.i32());
+ auto* idx_b = b.FunctionParam(ty.i32());
+ auto* idx_c = b.FunctionParam(ty.i32());
+ auto* func = b.CreateFunction("foo", ty.i32());
+ func->SetParams(utils::Vector{arr, idx_a, idx_b, idx_c});
+
+ auto* access_a =
+ b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_a});
+ auto* access_b =
+ b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_b});
+ auto* access_c =
+ b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_c});
+ func->StartTarget()->Append(access_a);
+ func->StartTarget()->Append(access_b);
+ func->StartTarget()->Append(access_c);
+ func->StartTarget()->Append(b.Return(func, utils::Vector{access_c}));
+ mod.functions.Push(func);
+
+ auto* expect = R"(
+%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32, %4:i32, %5:i32):i32 -> %b1 {
+ %b1 = block {
+ %6:array<i32, 4> = access %2, 1u, 2u
+ %7:ptr<function, array<i32, 4>, read_write> = var, %6
+ %8:ptr<function, i32, read_write> = access %7, %3
+ %9:i32 = load %8
+ %10:ptr<function, i32, read_write> = access %7, %4
+ %11:i32 = load %10
+ %12:ptr<function, i32, read_write> = access %7, %5
+ %13:i32 = load %12
+ ret %13
+ }
+}
+)";
+
+ Run<VarForDynamicIndex>();
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::ir::transform
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 773c96f..a173b49 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -38,6 +38,7 @@
#include "src/tint/ir/switch.h"
#include "src/tint/ir/transform/add_empty_entry_point.h"
#include "src/tint/ir/transform/block_decorated_structs.h"
+#include "src/tint/ir/transform/var_for_dynamic_index.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/validate.h"
#include "src/tint/ir/var.h"
@@ -68,6 +69,7 @@
manager.Add<ir::transform::AddEmptyEntryPoint>();
manager.Add<ir::transform::BlockDecoratedStructs>();
+ manager.Add<ir::transform::VarForDynamicIndex>();
transform::DataMap outputs;
manager.Run(module, data, outputs);
@@ -566,7 +568,6 @@
// For non-pointer types, we assume that the indices are constants and use OpCompositeExtract.
// If we hit a non-constant index into a vector type, use OpVectorExtractDynamic for it.
- // TODO(jrprice): Port VarForDynamicIndex transform to IR to make the above assertion true.
auto* ty = access->Object()->Type();
for (auto* idx : access->Indices()) {
if (auto* constant = idx->As<ir::Constant>()) {