Add transform::CalculateArrayLength
Used to used to replace calls to arrayLength() with a value calculated from the size of the storage buffer.
Bug: tint:185
Change-Id: If7ddc8dad2ed3d20c1d76b5f48bfd2c7634f96e2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46877
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 7b0aa75..8a7dc89 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -403,6 +403,8 @@
"transform/binding_remapper.h",
"transform/bound_array_accessors.cc",
"transform/bound_array_accessors.h",
+ "transform/calculate_array_length.cc",
+ "transform/calculate_array_length.h",
"transform/canonicalize_entry_point_io.cc",
"transform/canonicalize_entry_point_io.h",
"transform/decompose_storage_access.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 1542687..b920e8c 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -218,6 +218,8 @@
transform/binding_remapper.h
transform/bound_array_accessors.cc
transform/bound_array_accessors.h
+ transform/calculate_array_length.cc
+ transform/calculate_array_length.h
transform/canonicalize_entry_point_io.cc
transform/canonicalize_entry_point_io.h
transform/decompose_storage_access.cc
@@ -731,6 +733,7 @@
list(APPEND TINT_TEST_SRCS
transform/binding_remapper_test.cc
transform/bound_array_accessors_test.cc
+ transform/calculate_array_length_test.cc
transform/canonicalize_entry_point_io_test.cc
transform/decompose_storage_access_test.cc
transform/emit_vertex_point_size_test.cc
diff --git a/src/transform/calculate_array_length.cc b/src/transform/calculate_array_length.cc
new file mode 100644
index 0000000..3d57d9c
--- /dev/null
+++ b/src/transform/calculate_array_length.cc
@@ -0,0 +1,236 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/calculate_array_length.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "src/ast/call_statement.h"
+#include "src/program_builder.h"
+#include "src/semantic/call.h"
+#include "src/semantic/statement.h"
+#include "src/semantic/struct.h"
+#include "src/semantic/variable.h"
+#include "src/utils/get_or_create.h"
+#include "src/utils/hash.h"
+
+TINT_INSTANTIATE_TYPEINFO(
+ tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
+
+namespace tint {
+namespace transform {
+
+namespace {
+
+/// ArrayUsage describes a runtime array usage.
+/// It is used as a key by the array_length_by_usage map.
+struct ArrayUsage {
+ ast::BlockStatement const* const block;
+ semantic::Node const* const buffer;
+ bool operator==(const ArrayUsage& rhs) const {
+ return block == rhs.block && buffer == rhs.buffer;
+ }
+ struct Hasher {
+ inline std::size_t operator()(const ArrayUsage& u) const {
+ return utils::Hash(u.block, u.buffer);
+ }
+ };
+};
+
+} // namespace
+
+CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic() = default;
+CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
+std::string CalculateArrayLength::BufferSizeIntrinsic::Name() const {
+ return "intrinsic_buffer_size";
+}
+
+CalculateArrayLength::BufferSizeIntrinsic*
+CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
+ return ctx->dst->ASTNodes()
+ .Create<CalculateArrayLength::BufferSizeIntrinsic>();
+}
+
+CalculateArrayLength::CalculateArrayLength() = default;
+CalculateArrayLength::~CalculateArrayLength() = default;
+
+Transform::Output CalculateArrayLength::Run(const Program* in, const DataMap&) {
+ ProgramBuilder out;
+ CloneContext ctx(&out, in);
+
+ auto& sem = ctx.src->Sem();
+
+ // get_buffer_size_intrinsic() emits the function decorated with
+ // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
+ // [RW]ByteAddressBuffer.GetDimensions().
+ std::unordered_map<type::Struct*, Symbol> buffer_size_intrinsics;
+ auto get_buffer_size_intrinsic = [&](type::Struct* buffer_type) {
+ return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
+ auto name = ctx.dst->Symbols().New();
+ auto* func = ctx.dst->create<ast::Function>(
+ name,
+ ast::VariableList{
+ // Note: The buffer parameter requires the kStorage StorageClass
+ // in order for HLSL to emit this as a ByteAddressBuffer.
+ ctx.dst->create<ast::Variable>(
+ ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
+ ctx.Clone(buffer_type), true, nullptr, ast::DecorationList{}),
+ ctx.dst->Param("result",
+ ctx.dst->ty.pointer(ctx.dst->ty.u32(),
+ ast::StorageClass::kFunction)),
+ },
+ ctx.dst->ty.void_(), nullptr,
+ ast::DecorationList{
+ ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(),
+ },
+ ast::DecorationList{});
+ ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), buffer_type, func);
+ return name;
+ });
+ };
+
+ std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher>
+ array_length_by_usage;
+
+ // Find all the arrayLength() calls...
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* call_expr = node->As<ast::CallExpression>()) {
+ auto* call = sem.Get(call_expr);
+ if (auto* intrinsic = call->Target()->As<semantic::Intrinsic>()) {
+ if (intrinsic->Type() == semantic::IntrinsicType::kArrayLength) {
+ // We're dealing with an arrayLength() call
+
+ // https://gpuweb.github.io/gpuweb/wgsl.html#array-types states:
+ //
+ // * The last member of the structure type defining the store type for
+ // a variable in the storage storage class may be a runtime-sized
+ // array.
+ // * A runtime-sized array must not be used as the store type or
+ // contained within a store type in any other cases.
+ // * The type of an expression must not be a runtime-sized array type.
+ // arrayLength()
+ //
+ // We can assume that the arrayLength() call has a single argument of
+ // the form: arrayLength(X.Y) where X is an expression that resolves
+ // to the storage buffer structure, and Y is the runtime sized array.
+ auto* array_expr = call_expr->params()[0];
+ auto* accessor = array_expr->As<ast::MemberAccessorExpression>();
+ if (!accessor) {
+ TINT_ICE(ctx.dst->Diagnostics())
+ << "arrayLength() expected ast::MemberAccessorExpression, got "
+ << array_expr->TypeInfo().name;
+ break;
+ }
+ auto* storage_buffer_expr = accessor->structure();
+ auto* storage_buffer_sem = sem.Get(storage_buffer_expr);
+ auto* storage_buffer_type =
+ storage_buffer_sem->Type()->UnwrapAll()->As<type::Struct>();
+
+ // Generate BufferSizeIntrinsic for this storage type if we haven't
+ // already
+ auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
+
+ if (!storage_buffer_type) {
+ TINT_ICE(ctx.dst->Diagnostics())
+ << "arrayLength(X.Y) expected X to be type::Struct, got "
+ << storage_buffer_type->FriendlyName(ctx.src->Symbols());
+ break;
+ }
+
+ // Find the current statement block
+ auto* block = call->Stmt()->Block();
+ if (!block) {
+ TINT_ICE(ctx.dst->Diagnostics())
+ << "arrayLength() statement is outside a BlockStatement";
+ break;
+ }
+
+ // If the storage_buffer_expr is resolves to a variable (typically
+ // true) then key the array_length from the variable. If not, key off
+ // the expression semantic node, which will be unique per call to
+ // arrayLength().
+ const semantic::Node* storage_buffer_usage = storage_buffer_sem;
+ if (auto* user = storage_buffer_sem->As<semantic::VariableUser>()) {
+ storage_buffer_usage = user->Variable();
+ }
+
+ auto array_length = utils::GetOrCreate(
+ array_length_by_usage, {block, storage_buffer_usage}, [&] {
+ // First time this array length is used for this block.
+ // Let's calculate it.
+
+ // Semantic info for the storage buffer structure
+ auto* storage_buffer_type_sem =
+ ctx.src->Sem().Get(storage_buffer_type);
+ // Semantic info for the runtime array structure member
+ auto* array_member_sem =
+ storage_buffer_type_sem->Members().back();
+
+ // Construct the variable that'll hold the result of
+ // RWByteAddressBuffer.GetDimensions()
+ auto* buffer_size_result =
+ ctx.dst->create<ast::VariableDeclStatement>(ctx.dst->Var(
+ ctx.dst->Symbols().New(), ctx.dst->ty.u32(),
+ ast::StorageClass::kFunction, ctx.dst->Expr(0u)));
+
+ // Call storage_buffer.GetDimensions(buffer_size_result)
+ auto* call_get_dims =
+ ctx.dst->create<ast::CallStatement>(ctx.dst->Call(
+ // BufferSizeIntrinsic(X, ARGS...) is
+ // translated to:
+ // X.GetDimensions(ARGS..) by the writer
+ buffer_size, ctx.Clone(storage_buffer_expr),
+ buffer_size_result->variable()->symbol()));
+
+ // Calculate actual array length
+ // total_storage_buffer_size - array_offset
+ // array_length = ----------------------------------------
+ // array_stride
+ auto name = ctx.dst->Symbols().New();
+ uint32_t array_offset = array_member_sem->Offset();
+ uint32_t array_stride = array_member_sem->Size();
+ auto* array_length_var =
+ ctx.dst->create<ast::VariableDeclStatement>(ctx.dst->Const(
+ name, ctx.dst->ty.u32(),
+ ctx.dst->Div(
+ ctx.dst->Sub(
+ buffer_size_result->variable()->symbol(),
+ array_offset),
+ array_stride)));
+
+ // Insert the array length calculations at the top of the block
+ ctx.InsertBefore(block->statements(), *block->begin(),
+ buffer_size_result);
+ ctx.InsertBefore(block->statements(), *block->begin(),
+ call_get_dims);
+ ctx.InsertBefore(block->statements(), *block->begin(),
+ array_length_var);
+ return name;
+ });
+
+ // Replace the call to arrayLength() with the array length variable
+ ctx.Replace(call_expr, ctx.dst->Expr(array_length));
+ }
+ }
+ }
+ }
+
+ ctx.Clone();
+
+ return Output{Program(std::move(out))};
+}
+
+} // namespace transform
+} // namespace tint
diff --git a/src/transform/calculate_array_length.h b/src/transform/calculate_array_length.h
new file mode 100644
index 0000000..fa96d81
--- /dev/null
+++ b/src/transform/calculate_array_length.h
@@ -0,0 +1,68 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
+#define SRC_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
+
+#include <string>
+
+#include "src/ast/internal_decoration.h"
+#include "src/transform/transform.h"
+
+namespace tint {
+
+// Forward declarations
+class CloneContext;
+
+namespace transform {
+
+/// CalculateArrayLength is a transform used to replace calls to arrayLength()
+/// with a value calculated from the size of the storage buffer.
+class CalculateArrayLength : public Transform {
+ public:
+ /// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic
+ /// functions used to obtain the runtime size of a storage buffer.
+ class BufferSizeIntrinsic
+ : public Castable<BufferSizeIntrinsic, ast::InternalDecoration> {
+ public:
+ /// Constructor
+ BufferSizeIntrinsic();
+ /// Destructor
+ ~BufferSizeIntrinsic() override;
+
+ /// @return "buffer_size"
+ std::string Name() const override;
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ BufferSizeIntrinsic* Clone(CloneContext* ctx) const override;
+ };
+
+ /// Constructor
+ CalculateArrayLength();
+ /// Destructor
+ ~CalculateArrayLength() override;
+
+ /// Runs the transform on `program`, returning the transformation result.
+ /// @param program the source program to transform
+ /// @param data optional extra transform-specific data
+ /// @returns the transformation result
+ Output Run(const Program* program, const DataMap& data = {}) override;
+};
+
+} // namespace transform
+} // namespace tint
+
+#endif // SRC_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
diff --git a/src/transform/calculate_array_length_test.cc b/src/transform/calculate_array_length_test.cc
new file mode 100644
index 0000000..42b2f10
--- /dev/null
+++ b/src/transform/calculate_array_length_test.cc
@@ -0,0 +1,284 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/calculate_array_length.h"
+
+#include "src/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using CalculateArrayLengthTest = TransformTest;
+
+TEST_F(CalculateArrayLengthTest, Basic) {
+ auto* src = R"(
+[[block]]
+struct SB {
+ x : i32;
+ arr : array<i32>;
+};
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+ var len : u32 = arrayLength(sb.arr);
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct SB {
+ x : i32;
+ arr : array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB, result : ptr<function, u32>)
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+ var tint_symbol_7 : u32 = 0u;
+ tint_symbol_1(sb, tint_symbol_7);
+ let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u);
+ var len : u32 = tint_symbol_9;
+}
+)";
+
+ auto got = Run<CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, InSameBlock) {
+ auto* src = R"(
+[[block]]
+struct SB {
+ x : i32;
+ arr : array<i32>;
+};
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+ var a : u32 = arrayLength(sb.arr);
+ var b : u32 = arrayLength(sb.arr);
+ var c : u32 = arrayLength(sb.arr);
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct SB {
+ x : i32;
+ arr : array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB, result : ptr<function, u32>)
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+ var tint_symbol_7 : u32 = 0u;
+ tint_symbol_1(sb, tint_symbol_7);
+ let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u);
+ var a : u32 = tint_symbol_9;
+ var b : u32 = tint_symbol_9;
+ var c : u32 = tint_symbol_9;
+}
+)";
+
+ auto got = Run<CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, WithStride) {
+ auto* src = R"(
+[[block]]
+struct SB {
+ x : i32;
+ y : f32;
+ arr : [[stride(64)]] array<i32>;
+};
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+ var len : u32 = arrayLength(sb.arr);
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct SB {
+ x : i32;
+ y : f32;
+ arr : [[stride(64)]] array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB, result : ptr<function, u32>)
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+ var tint_symbol_8 : u32 = 0u;
+ tint_symbol_1(sb, tint_symbol_8);
+ let tint_symbol_10 : u32 = ((tint_symbol_8 - 8u) / 64u);
+ var len : u32 = tint_symbol_10;
+}
+)";
+
+ auto got = Run<CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, Nested) {
+ auto* src = R"(
+[[block]]
+struct SB {
+ x : i32;
+ arr : array<i32>;
+};
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+ if (true) {
+ var len : u32 = arrayLength(sb.arr);
+ } else {
+ if (true) {
+ var len : u32 = arrayLength(sb.arr);
+ }
+ }
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct SB {
+ x : i32;
+ arr : array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB, result : ptr<function, u32>)
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+ if (true) {
+ var tint_symbol_7 : u32 = 0u;
+ tint_symbol_1(sb, tint_symbol_7);
+ let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u);
+ var len : u32 = tint_symbol_9;
+ } else {
+ if (true) {
+ var tint_symbol_10 : u32 = 0u;
+ tint_symbol_1(sb, tint_symbol_10);
+ let tint_symbol_11 : u32 = ((tint_symbol_10 - 4u) / 4u);
+ var len : u32 = tint_symbol_11;
+ }
+ }
+}
+)";
+
+ auto got = Run<CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, MultipleStorageBuffers) {
+ auto* src = R"(
+[[block]]
+struct SB1 {
+ x : i32;
+ arr1 : array<i32>;
+};
+
+[[block]]
+struct SB2 {
+ x : i32;
+ arr2 : array<vec4<f32>>;
+};
+
+var<storage> sb1 : SB1;
+
+var<storage> sb2 : SB2;
+
+[[stage(vertex)]]
+fn main() {
+ var len1 : u32 = arrayLength(sb1.arr1);
+ var len2 : u32 = arrayLength(sb2.arr2);
+ var x : u32 = (len1 + len2);
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct SB1 {
+ x : i32;
+ arr1 : array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB1, result : ptr<function, u32>)
+
+[[block]]
+struct SB2 {
+ x : i32;
+ arr2 : array<vec4<f32>>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_10(buffer : SB2, result : ptr<function, u32>)
+
+var<storage> sb1 : SB1;
+
+var<storage> sb2 : SB2;
+
+[[stage(vertex)]]
+fn main() {
+ var tint_symbol_7 : u32 = 0u;
+ tint_symbol_1(sb1, tint_symbol_7);
+ let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u);
+ var tint_symbol_13 : u32 = 0u;
+ tint_symbol_10(sb2, tint_symbol_13);
+ let tint_symbol_15 : u32 = ((tint_symbol_13 - 16u) / 16u);
+ var len1 : u32 = tint_symbol_9;
+ var len2 : u32 = tint_symbol_15;
+ var x : u32 = (len1 + len2);
+}
+)";
+
+ auto got = Run<CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace transform
+} // namespace tint
diff --git a/test/BUILD.gn b/test/BUILD.gn
index ad3fe6e..6827f4b 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -194,6 +194,7 @@
"../src/traits_test.cc",
"../src/transform/binding_remapper_test.cc",
"../src/transform/bound_array_accessors_test.cc",
+ "../src/transform/calculate_array_length_test.cc",
"../src/transform/canonicalize_entry_point_io_test.cc",
"../src/transform/decompose_storage_access_test.cc",
"../src/transform/emit_vertex_point_size_test.cc",