[ir][spirv-writer] VarForDynamicIndex transform

Find access instructions that dynamically index into array and matrix
values, and replace them with accesses into a locally declared copy of
the value. This is needed because SPIR-V has no instruction that can
dynamically index a non-pointer composite.

Use the transform in the SPIR-V writer.

Bug: tint:1906
Change-Id: Ie3872d8cf005948f64e73c455f5a38bfdab692c4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/136360
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 1ca8658..bd06461 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn
@@ -523,6 +523,8 @@ "ir/transform/add_empty_entry_point.h", "ir/transform/block_decorated_structs.cc", "ir/transform/block_decorated_structs.h", + "ir/transform/var_for_dynamic_index.cc", + "ir/transform/var_for_dynamic_index.h", ] deps = [ ":libtint_builtins_src", @@ -1818,6 +1820,7 @@ "ir/transform/add_empty_entry_point_test.cc", "ir/transform/block_decorated_structs_test.cc", "ir/transform/test_helper.h", + "ir/transform/var_for_dynamic_index_test.cc", ] deps = [
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index af06ca5..764be39 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt
@@ -804,6 +804,8 @@ ir/transform/block_decorated_structs.h ir/transform/transform.cc ir/transform/transform.h + ir/transform/var_for_dynamic_index.cc + ir/transform/var_for_dynamic_index.h ) endif() @@ -1537,6 +1539,7 @@ ir/swizzle_test.cc ir/transform/add_empty_entry_point_test.cc ir/transform/block_decorated_structs_test.cc + ir/transform/var_for_dynamic_index_test.cc ir/unary_test.cc ir/user_call_test.cc ir/validate_test.cc
diff --git a/src/tint/ir/access.h b/src/tint/ir/access.h index c798b13..55f7306 100644 --- a/src/tint/ir/access.h +++ b/src/tint/ir/access.h
@@ -41,6 +41,9 @@ return operands_.Slice().Offset(1).Reinterpret<Value const* const>(); } + /// @returns the accessor indices + utils::Slice<Value*> Indices() { return operands_.Slice().Offset(1); } + private: const type::Type* result_type_ = nullptr; };
diff --git a/src/tint/ir/transform/var_for_dynamic_index.cc b/src/tint/ir/transform/var_for_dynamic_index.cc new file mode 100644 index 0000000..fa063e7 --- /dev/null +++ b/src/tint/ir/transform/var_for_dynamic_index.cc
@@ -0,0 +1,174 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ir/transform/var_for_dynamic_index.h" + +#include <utility> + +#include "src/tint/ir/builder.h" +#include "src/tint/ir/module.h" +#include "src/tint/switch.h" +#include "src/tint/type/array.h" +#include "src/tint/type/matrix.h" +#include "src/tint/type/pointer.h" +#include "src/tint/type/struct.h" +#include "src/tint/type/vector.h" +#include "src/tint/utils/hashmap.h" + +TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::VarForDynamicIndex); + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::ir::transform { + +namespace { +// An access that needs replacing. +struct AccessToReplace { + // The access instruction. + Access* access = nullptr; + // The index of the first dynamic index. + uint32_t first_dynamic_index = 0; + // The object type that corresponds to the source of the first dynamic index. + const type::Type* dynamic_index_source_type = nullptr; +}; + +// A partial access chain that uses constant indices to get to an object that will be +// dynamically indexed. +struct PartialAccess { + // The base object. + Value* base = nullptr; + // The list of constant indices to get from the base to the source object. + utils::Vector<Value*, 4> indices; + + // A specialization of utils::Hasher for PartialAccess. 
+ struct Hasher { + inline std::size_t operator()(const PartialAccess& src) const { + return utils::Hash(src.base, src.indices); + } + }; + + // An equality helper for PartialAccess. + bool operator==(const PartialAccess& other) const { + return base == other.base && indices == other.indices; + } +}; +} // namespace + +VarForDynamicIndex::VarForDynamicIndex() = default; + +VarForDynamicIndex::~VarForDynamicIndex() = default; + +static std::optional<AccessToReplace> ShouldReplace(Access* access) { + AccessToReplace to_replace{access, 0, access->Object()->Type()}; + + // Find the first dynamic index, if any. + bool has_dynamic_index = false; + for (auto* idx : access->Indices()) { + if (to_replace.dynamic_index_source_type->Is<type::Vector>()) { + // Stop if we hit a vector, as they can support dynamic accesses. + break; + } + + // Check if the index is dynamic. + auto* const_idx = idx->As<Constant>(); + if (!const_idx) { + has_dynamic_index = true; + break; + } + to_replace.first_dynamic_index++; + + // Update the current object type. + to_replace.dynamic_index_source_type = tint::Switch( + to_replace.dynamic_index_source_type, // + [&](const type::Array* arr) { return arr->ElemType(); }, + [&](const type::Matrix* mat) { return mat->ColumnType(); }, + [&](const type::Struct* str) { + return str->Members()[const_idx->Value()->ValueAs<u32>()]->Type(); + }, + [&](const type::Vector* vec) { return vec->type(); }, // + [&](Default) { return nullptr; }); + } + if (!has_dynamic_index) { + // No need to modify accesses that only use constant indices. + return {}; + } + + return to_replace; +} + +void VarForDynamicIndex::Run(ir::Module* ir, const DataMap&, DataMap&) const { + ir::Builder builder(*ir); + + // Find the access instructions that need replacing. 
+ utils::Vector<AccessToReplace, 4> worklist; + for (auto* inst : ir->values.Objects()) { + auto* access = inst->As<Access>(); + if (access && !access->Type()->Is<type::Pointer>()) { + if (auto to_replace = ShouldReplace(access)) { + worklist.Push(to_replace.value()); + } + } + } + + // Replace each access instruction that we recorded. + utils::Hashmap<Value*, Value*, 4> object_to_local; + utils::Hashmap<PartialAccess, Value*, 4, PartialAccess::Hasher> source_object_to_value; + for (const auto& to_replace : worklist) { + auto* access = to_replace.access; + Value* source_object = access->Object(); + + // If the access starts with at least one constant index, extract the source of the first + // dynamic access to avoid copying the whole object. + if (to_replace.first_dynamic_index > 0) { + PartialAccess partial_access = { + access->Object(), access->Indices().Truncate(to_replace.first_dynamic_index)}; + source_object = source_object_to_value.GetOrCreate(partial_access, [&]() { + auto* intermediate_source = builder.Access(to_replace.dynamic_index_source_type, + source_object, partial_access.indices); + intermediate_source->InsertBefore(access); + return intermediate_source; + }); + } + + // Declare a local variable and copy the source object to it. + auto* local = object_to_local.GetOrCreate(source_object, [&]() { + auto* decl = builder.Declare(ir->Types().pointer(to_replace.dynamic_index_source_type, + builtin::AddressSpace::kFunction, + builtin::Access::kReadWrite)); + decl->SetInitializer(source_object); + decl->InsertBefore(access); + return decl; + }); + + // Create a new access instruction using the local variable as the source. 
+ utils::Vector<Value*, 4> indices{access->Indices().Offset(to_replace.first_dynamic_index)}; + auto* new_access = + builder.Access(ir->Types().pointer(access->Type(), builtin::AddressSpace::kFunction, + builtin::Access::kReadWrite), + local, indices); + access->ReplaceWith(new_access); + + // Load from the access to get the final result value. + auto* load = builder.Load(new_access); + load->InsertAfter(new_access); + + // Replace all uses of the old access instruction with the loaded result. + while (!access->Usages().IsEmpty()) { + auto& use = *access->Usages().begin(); + use.instruction->SetOperand(use.operand_index, load); + } + } +} + +} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/var_for_dynamic_index.h b/src/tint/ir/transform/var_for_dynamic_index.h new file mode 100644 index 0000000..1f86b7d --- /dev/null +++ b/src/tint/ir/transform/var_for_dynamic_index.h
@@ -0,0 +1,43 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
+#define SRC_TINT_IR_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
+
+#include "src/tint/ir/transform/transform.h"
+
+namespace tint::ir::transform {
+
+/// VarForDynamicIndex is a transform that copies array and matrix values that are dynamically
+/// indexed to a temporary local `var` before performing the index. This transform is used by the
+/// SPIR-V writer as there is no SPIR-V instruction that can dynamically index a non-pointer
+/// composite.
+///
+/// Accesses whose indices are all constant, accesses that produce pointers, and dynamic indices
+/// into vector values are left unchanged. When the first dynamic index is preceded by constant
+/// indices, only the sub-object reached by those constant indices is copied.
+class VarForDynamicIndex final : public utils::Castable<VarForDynamicIndex, Transform> {
+  public:
+    /// Constructor
+    VarForDynamicIndex();
+    /// Destructor
+    ~VarForDynamicIndex() override;
+
+    /// @copydoc Transform::Run
+    void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+}  // namespace tint::ir::transform
+
+#endif  // SRC_TINT_IR_TRANSFORM_VAR_FOR_DYNAMIC_INDEX_H_
diff --git a/src/tint/ir/transform/var_for_dynamic_index_test.cc b/src/tint/ir/transform/var_for_dynamic_index_test.cc new file mode 100644 index 0000000..0f6af06 --- /dev/null +++ b/src/tint/ir/transform/var_for_dynamic_index_test.cc
@@ -0,0 +1,428 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ir/transform/var_for_dynamic_index.h" + +#include <utility> + +#include "src/tint/ir/transform/test_helper.h" +#include "src/tint/type/array.h" +#include "src/tint/type/matrix.h" +#include "src/tint/type/struct.h" + +namespace tint::ir::transform { +namespace { + +using namespace tint::number_suffixes; // NOLINT + +class IR_VarForDynamicIndexTest : public TransformTest { + protected: + const type::Type* ptr(const type::Type* elem) { + return ty.pointer(elem, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite); + } +}; + +TEST_F(IR_VarForDynamicIndexTest, NoModify_ConstantIndex_ArrayValue) { + auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u)); + auto* func = b.CreateFunction("foo", ty.i32()); + func->SetParams(utils::Vector{arr}); + + auto* access = b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_i)}); + func->StartTarget()->Append(access); + func->StartTarget()->Append(b.Return(func, utils::Vector{access})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:array<i32, 4>):i32 -> %b1 { + %b1 = block { + %3:i32 = access %2, 1i + ret %3 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, NoModify_ConstantIndex_MatrixValue) { + auto* mat = b.FunctionParam(ty.mat2x2(ty.f32())); + auto* func = b.CreateFunction("foo", ty.f32()); + 
func->SetParams(utils::Vector{mat}); + + auto* access = b.Access(ty.f32(), mat, utils::Vector{b.Constant(1_i), b.Constant(0_i)}); + func->StartTarget()->Append(access); + func->StartTarget()->Append(b.Return(func, utils::Vector{access})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:mat2x2<f32>):f32 -> %b1 { + %b1 = block { + %3:f32 = access %2, 1i, 0i + ret %3 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_ArrayPointer) { + auto* arr = b.FunctionParam(ptr(ty.array(ty.i32(), 4u))); + auto* idx = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.i32()); + func->SetParams(utils::Vector{arr, idx}); + + auto* access = b.Access(ptr(ty.i32()), arr, utils::Vector{idx}); + auto* load = b.Load(access); + func->StartTarget()->Append(access); + func->StartTarget()->Append(load); + func->StartTarget()->Append(b.Return(func, utils::Vector{load})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:ptr<function, array<i32, 4>, read_write>, %3:i32):i32 -> %b1 { + %b1 = block { + %4:ptr<function, i32, read_write> = access %2, %3 + %5:i32 = load %4 + ret %5 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_MatrixPointer) { + auto* mat = b.FunctionParam(ptr(ty.mat2x2(ty.f32()))); + auto* idx = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.f32()); + func->SetParams(utils::Vector{mat, idx}); + + auto* access = b.Access(ptr(ty.f32()), mat, utils::Vector{idx, idx}); + auto* load = b.Load(access); + func->StartTarget()->Append(access); + func->StartTarget()->Append(load); + func->StartTarget()->Append(b.Return(func, utils::Vector{load})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:ptr<function, mat2x2<f32>, read_write>, %3:i32):f32 -> %b1 { + %b1 = block { + %4:ptr<function, f32, read_write> = access %2, %3, %3 + 
%5:f32 = load %4 + ret %5 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_VectorValue) { + auto* vec = b.FunctionParam(ty.vec4(ty.f32())); + auto* idx = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.f32()); + func->SetParams(utils::Vector{vec, idx}); + + auto* access = b.Access(ty.f32(), vec, utils::Vector{idx}); + func->StartTarget()->Append(access); + func->StartTarget()->Append(b.Return(func, utils::Vector{access})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:vec4<f32>, %3:i32):f32 -> %b1 { + %b1 = block { + %4:f32 = access %2, %3 + ret %4 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, DynamicIndex_ArrayValue) { + auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u)); + auto* idx = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.i32()); + func->SetParams(utils::Vector{arr, idx}); + + auto* access = b.Access(ty.i32(), arr, utils::Vector{idx}); + func->StartTarget()->Append(access); + func->StartTarget()->Append(b.Return(func, utils::Vector{access})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:array<i32, 4>, %3:i32):i32 -> %b1 { + %b1 = block { + %4:ptr<function, array<i32, 4>, read_write> = var, %2 + %5:ptr<function, i32, read_write> = access %4, %3 + %6:i32 = load %5 + ret %6 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, DynamicIndex_MatrixValue) { + auto* arr = b.FunctionParam(ty.mat2x2(ty.f32())); + auto* idx = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.f32()); + func->SetParams(utils::Vector{arr, idx}); + + auto* access = b.Access(ty.f32(), arr, utils::Vector{idx}); + func->StartTarget()->Append(access); + func->StartTarget()->Append(b.Return(func, utils::Vector{access})); + mod.functions.Push(func); + + auto* expect = R"( +%foo 
= func(%2:mat2x2<f32>, %3:i32):f32 -> %b1 { + %b1 = block { + %4:ptr<function, mat2x2<f32>, read_write> = var, %2 + %5:ptr<function, f32, read_write> = access %4, %3 + %6:f32 = load %5 + ret %6 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, AccessChain) { + auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u)); + auto* idx = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.i32()); + func->SetParams(utils::Vector{arr, idx}); + + auto* access = b.Access(ty.i32(), arr, utils::Vector{idx, b.Constant(1_u), idx}); + func->StartTarget()->Append(access); + func->StartTarget()->Append(b.Return(func, utils::Vector{access})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32):i32 -> %b1 { + %b1 = block { + %4:ptr<function, array<array<array<i32, 4>, 4>, 4>, read_write> = var, %2 + %5:ptr<function, i32, read_write> = access %4, %3, 1u, %3 + %6:i32 = load %5 + ret %6 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices) { + auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u)); + auto* idx = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.i32()); + func->SetParams(utils::Vector{arr, idx}); + + auto* access = b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx}); + func->StartTarget()->Append(access); + func->StartTarget()->Append(b.Return(func, utils::Vector{access})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32):i32 -> %b1 { + %b1 = block { + %4:array<i32, 4> = access %2, 1u, 2u + %5:ptr<function, array<i32, 4>, read_write> = var, %4 + %6:ptr<function, i32, read_write> = access %5, %3 + %7:i32 = load %6 + ret %7 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); 
+} + +TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices_Interleaved) { + auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u), 4u)); + auto* idx = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.i32()); + func->SetParams(utils::Vector{arr, idx}); + + auto* access = + b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), idx, b.Constant(2_u), idx}); + func->StartTarget()->Append(access); + func->StartTarget()->Append(b.Return(func, utils::Vector{access})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:array<array<array<array<i32, 4>, 4>, 4>, 4>, %3:i32):i32 -> %b1 { + %b1 = block { + %4:array<array<array<i32, 4>, 4>, 4> = access %2, 1u + %5:ptr<function, array<array<array<i32, 4>, 4>, 4>, read_write> = var, %4 + %6:ptr<function, i32, read_write> = access %5, %3, 2u, %3 + %7:i32 = load %6 + ret %7 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices_Struct) { + auto* str_ty = ty.Get<type::Struct>( + mod.symbols.Register("MyStruct"), + utils::Vector{ + ty.Get<type::StructMember>(mod.symbols.Register("arr1"), ty.array(ty.f32(), 1024u), 0u, + 0u, 4u, 4096u, type::StructMemberAttributes{}), + ty.Get<type::StructMember>(mod.symbols.Register("mat"), ty.mat4x4(ty.f32()), 1u, 4096u, + 16u, 64u, type::StructMemberAttributes{}), + ty.Get<type::StructMember>(mod.symbols.Register("arr2"), ty.array(ty.f32(), 1024u), 2u, + 4160u, 4u, 4096u, type::StructMemberAttributes{}), + }, + 16u, 32u, 32u); + auto* str_val = b.FunctionParam(str_ty); + auto* idx = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.f32()); + func->SetParams(utils::Vector{str_val, idx}); + + auto* access = + b.Access(ty.f32(), str_val, utils::Vector{b.Constant(1_u), idx, b.Constant(0_u)}); + func->StartTarget()->Append(access); + func->StartTarget()->Append(b.Return(func, utils::Vector{access})); + 
mod.functions.Push(func); + + auto* expect = R"( +MyStruct = struct @align(16) { + arr1:array<f32, 1024> @offset(0) + mat:mat4x4<f32> @offset(4096) + arr2:array<f32, 1024> @offset(4160) +} + +%foo = func(%2:MyStruct, %3:i32):f32 -> %b1 { + %b1 = block { + %4:mat4x4<f32> = access %2, 1u + %5:ptr<function, mat4x4<f32>, read_write> = var, %4 + %6:ptr<function, f32, read_write> = access %5, %3, 0u + %7:f32 = load %6 + ret %7 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, MultipleAccessesFromSameSource) { + auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u)); + auto* idx_a = b.FunctionParam(ty.i32()); + auto* idx_b = b.FunctionParam(ty.i32()); + auto* idx_c = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.i32()); + func->SetParams(utils::Vector{arr, idx_a, idx_b, idx_c}); + + auto* access_a = b.Access(ty.i32(), arr, utils::Vector{idx_a}); + auto* access_b = b.Access(ty.i32(), arr, utils::Vector{idx_b}); + auto* access_c = b.Access(ty.i32(), arr, utils::Vector{idx_c}); + func->StartTarget()->Append(access_a); + func->StartTarget()->Append(access_b); + func->StartTarget()->Append(access_c); + func->StartTarget()->Append(b.Return(func, utils::Vector{access_c})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:array<i32, 4>, %3:i32, %4:i32, %5:i32):i32 -> %b1 { + %b1 = block { + %6:ptr<function, array<i32, 4>, read_write> = var, %2 + %7:ptr<function, i32, read_write> = access %6, %3 + %8:i32 = load %7 + %9:ptr<function, i32, read_write> = access %6, %4 + %10:i32 = load %9 + %11:ptr<function, i32, read_write> = access %6, %5 + %12:i32 = load %11 + ret %12 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +TEST_F(IR_VarForDynamicIndexTest, MultipleAccessesFromSameSource_SkipConstantIndices) { + auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u)); + auto* idx_a = b.FunctionParam(ty.i32()); + auto* idx_b = 
b.FunctionParam(ty.i32()); + auto* idx_c = b.FunctionParam(ty.i32()); + auto* func = b.CreateFunction("foo", ty.i32()); + func->SetParams(utils::Vector{arr, idx_a, idx_b, idx_c}); + + auto* access_a = + b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_a}); + auto* access_b = + b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_b}); + auto* access_c = + b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_c}); + func->StartTarget()->Append(access_a); + func->StartTarget()->Append(access_b); + func->StartTarget()->Append(access_c); + func->StartTarget()->Append(b.Return(func, utils::Vector{access_c})); + mod.functions.Push(func); + + auto* expect = R"( +%foo = func(%2:array<array<array<i32, 4>, 4>, 4>, %3:i32, %4:i32, %5:i32):i32 -> %b1 { + %b1 = block { + %6:array<i32, 4> = access %2, 1u, 2u + %7:ptr<function, array<i32, 4>, read_write> = var, %6 + %8:ptr<function, i32, read_write> = access %7, %3 + %9:i32 = load %8 + %10:ptr<function, i32, read_write> = access %7, %4 + %11:i32 = load %10 + %12:ptr<function, i32, read_write> = access %7, %5 + %13:i32 = load %12 + ret %13 + } +} +)"; + + Run<VarForDynamicIndex>(); + + EXPECT_EQ(expect, str()); +} + +} // namespace +} // namespace tint::ir::transform
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc index 773c96f..a173b49 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -38,6 +38,7 @@ #include "src/tint/ir/switch.h" #include "src/tint/ir/transform/add_empty_entry_point.h" #include "src/tint/ir/transform/block_decorated_structs.h" +#include "src/tint/ir/transform/var_for_dynamic_index.h" #include "src/tint/ir/user_call.h" #include "src/tint/ir/validate.h" #include "src/tint/ir/var.h" @@ -68,6 +69,7 @@ manager.Add<ir::transform::AddEmptyEntryPoint>(); manager.Add<ir::transform::BlockDecoratedStructs>(); + manager.Add<ir::transform::VarForDynamicIndex>(); transform::DataMap outputs; manager.Run(module, data, outputs); @@ -566,7 +568,6 @@ // For non-pointer types, we assume that the indices are constants and use OpCompositeExtract. // If we hit a non-constant index into a vector type, use OpVectorExtractDynamic for it. - // TODO(jrprice): Port VarForDynamicIndex transform to IR to make the above assertion true. auto* ty = access->Object()->Type(); for (auto* idx : access->Indices()) { if (auto* constant = idx->As<ir::Constant>()) {