Add transform::CalculateArrayLength

Used to used to replace calls to arrayLength() with a value calculated from the size of the storage buffer.

Bug: tint:185
Change-Id: If7ddc8dad2ed3d20c1d76b5f48bfd2c7634f96e2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46877
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 7b0aa75..8a7dc89 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -403,6 +403,8 @@
     "transform/binding_remapper.h",
     "transform/bound_array_accessors.cc",
     "transform/bound_array_accessors.h",
+    "transform/calculate_array_length.cc",
+    "transform/calculate_array_length.h",
     "transform/canonicalize_entry_point_io.cc",
     "transform/canonicalize_entry_point_io.h",
     "transform/decompose_storage_access.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 1542687..b920e8c 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -218,6 +218,8 @@
   transform/binding_remapper.h
   transform/bound_array_accessors.cc
   transform/bound_array_accessors.h
+  transform/calculate_array_length.cc
+  transform/calculate_array_length.h
   transform/canonicalize_entry_point_io.cc
   transform/canonicalize_entry_point_io.h
   transform/decompose_storage_access.cc
@@ -731,6 +733,7 @@
     list(APPEND TINT_TEST_SRCS
       transform/binding_remapper_test.cc
       transform/bound_array_accessors_test.cc
+      transform/calculate_array_length_test.cc
       transform/canonicalize_entry_point_io_test.cc
       transform/decompose_storage_access_test.cc
       transform/emit_vertex_point_size_test.cc
diff --git a/src/transform/calculate_array_length.cc b/src/transform/calculate_array_length.cc
new file mode 100644
index 0000000..3d57d9c
--- /dev/null
+++ b/src/transform/calculate_array_length.cc
@@ -0,0 +1,236 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/calculate_array_length.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "src/ast/call_statement.h"
+#include "src/program_builder.h"
+#include "src/semantic/call.h"
+#include "src/semantic/statement.h"
+#include "src/semantic/struct.h"
+#include "src/semantic/variable.h"
+#include "src/utils/get_or_create.h"
+#include "src/utils/hash.h"
+
+TINT_INSTANTIATE_TYPEINFO(
+    tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
+
+namespace tint {
+namespace transform {
+
+namespace {
+
+/// ArrayUsage describes a runtime array usage.
+/// It is used as a key by the array_length_by_usage map.
+struct ArrayUsage {
+  ast::BlockStatement const* const block;
+  semantic::Node const* const buffer;
+  bool operator==(const ArrayUsage& rhs) const {
+    return block == rhs.block && buffer == rhs.buffer;
+  }
+  struct Hasher {
+    inline std::size_t operator()(const ArrayUsage& u) const {
+      return utils::Hash(u.block, u.buffer);
+    }
+  };
+};
+
+}  // namespace
+
+CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic() = default;
+CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
+std::string CalculateArrayLength::BufferSizeIntrinsic::Name() const {
+  return "intrinsic_buffer_size";
+}
+
+CalculateArrayLength::BufferSizeIntrinsic*
+CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
+  return ctx->dst->ASTNodes()
+      .Create<CalculateArrayLength::BufferSizeIntrinsic>();
+}
+
+CalculateArrayLength::CalculateArrayLength() = default;
+CalculateArrayLength::~CalculateArrayLength() = default;
+
+Transform::Output CalculateArrayLength::Run(const Program* in, const DataMap&) {
+  ProgramBuilder out;
+  CloneContext ctx(&out, in);
+
+  auto& sem = ctx.src->Sem();
+
+  // get_buffer_size_intrinsic() emits the function decorated with
+  // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
+  // [RW]ByteAddressBuffer.GetDimensions().
+  std::unordered_map<type::Struct*, Symbol> buffer_size_intrinsics;
+  auto get_buffer_size_intrinsic = [&](type::Struct* buffer_type) {
+    return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
+      auto name = ctx.dst->Symbols().New();
+      auto* func = ctx.dst->create<ast::Function>(
+          name,
+          ast::VariableList{
+              // Note: The buffer parameter requires the kStorage StorageClass
+              // in order for HLSL to emit this as a ByteAddressBuffer.
+              ctx.dst->create<ast::Variable>(
+                  ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
+                  ctx.Clone(buffer_type), true, nullptr, ast::DecorationList{}),
+              ctx.dst->Param("result",
+                             ctx.dst->ty.pointer(ctx.dst->ty.u32(),
+                                                 ast::StorageClass::kFunction)),
+          },
+          ctx.dst->ty.void_(), nullptr,
+          ast::DecorationList{
+              ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(),
+          },
+          ast::DecorationList{});
+      ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), buffer_type, func);
+      return name;
+    });
+  };
+
+  std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher>
+      array_length_by_usage;
+
+  // Find all the arrayLength() calls...
+  for (auto* node : ctx.src->ASTNodes().Objects()) {
+    if (auto* call_expr = node->As<ast::CallExpression>()) {
+      auto* call = sem.Get(call_expr);
+      if (auto* intrinsic = call->Target()->As<semantic::Intrinsic>()) {
+        if (intrinsic->Type() == semantic::IntrinsicType::kArrayLength) {
+          // We're dealing with an arrayLength() call
+
+          // https://gpuweb.github.io/gpuweb/wgsl.html#array-types states:
+          //
+          // * The last member of the structure type defining the store type for
+          //   a variable in the storage storage class may be a runtime-sized
+          //   array.
+          // * A runtime-sized array must not be used as the store type or
+          //   contained within a store type in any other cases.
+          // * The type of an expression must not be a runtime-sized array type.
+          //   arrayLength()
+          //
+          // We can assume that the arrayLength() call has a single argument of
+          // the form: arrayLength(X.Y) where X is an expression that resolves
+          // to the storage buffer structure, and Y is the runtime sized array.
+          auto* array_expr = call_expr->params()[0];
+          auto* accessor = array_expr->As<ast::MemberAccessorExpression>();
+          if (!accessor) {
+            TINT_ICE(ctx.dst->Diagnostics())
+                << "arrayLength() expected ast::MemberAccessorExpression, got "
+                << array_expr->TypeInfo().name;
+            break;
+          }
+          auto* storage_buffer_expr = accessor->structure();
+          auto* storage_buffer_sem = sem.Get(storage_buffer_expr);
+          auto* storage_buffer_type =
+              storage_buffer_sem->Type()->UnwrapAll()->As<type::Struct>();
+
+          // Generate BufferSizeIntrinsic for this storage type if we haven't
+          // already
+          auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
+
+          if (!storage_buffer_type) {
+            TINT_ICE(ctx.dst->Diagnostics())
+                << "arrayLength(X.Y) expected X to be type::Struct, got "
+                << storage_buffer_type->FriendlyName(ctx.src->Symbols());
+            break;
+          }
+
+          // Find the current statement block
+          auto* block = call->Stmt()->Block();
+          if (!block) {
+            TINT_ICE(ctx.dst->Diagnostics())
+                << "arrayLength() statement is outside a BlockStatement";
+            break;
+          }
+
+          // If the storage_buffer_expr is resolves to a variable (typically
+          // true) then key the array_length from the variable. If not, key off
+          // the expression semantic node, which will be unique per call to
+          // arrayLength().
+          const semantic::Node* storage_buffer_usage = storage_buffer_sem;
+          if (auto* user = storage_buffer_sem->As<semantic::VariableUser>()) {
+            storage_buffer_usage = user->Variable();
+          }
+
+          auto array_length = utils::GetOrCreate(
+              array_length_by_usage, {block, storage_buffer_usage}, [&] {
+                // First time this array length is used for this block.
+                // Let's calculate it.
+
+                // Semantic info for the storage buffer structure
+                auto* storage_buffer_type_sem =
+                    ctx.src->Sem().Get(storage_buffer_type);
+                // Semantic info for the runtime array structure member
+                auto* array_member_sem =
+                    storage_buffer_type_sem->Members().back();
+
+                // Construct the variable that'll hold the result of
+                // RWByteAddressBuffer.GetDimensions()
+                auto* buffer_size_result =
+                    ctx.dst->create<ast::VariableDeclStatement>(ctx.dst->Var(
+                        ctx.dst->Symbols().New(), ctx.dst->ty.u32(),
+                        ast::StorageClass::kFunction, ctx.dst->Expr(0u)));
+
+                // Call storage_buffer.GetDimensions(buffer_size_result)
+                auto* call_get_dims =
+                    ctx.dst->create<ast::CallStatement>(ctx.dst->Call(
+                        // BufferSizeIntrinsic(X, ARGS...) is
+                        // translated to:
+                        //  X.GetDimensions(ARGS..) by the writer
+                        buffer_size, ctx.Clone(storage_buffer_expr),
+                        buffer_size_result->variable()->symbol()));
+
+                // Calculate actual array length
+                //                total_storage_buffer_size - array_offset
+                // array_length = ----------------------------------------
+                //                             array_stride
+                auto name = ctx.dst->Symbols().New();
+                uint32_t array_offset = array_member_sem->Offset();
+                uint32_t array_stride = array_member_sem->Size();
+                auto* array_length_var =
+                    ctx.dst->create<ast::VariableDeclStatement>(ctx.dst->Const(
+                        name, ctx.dst->ty.u32(),
+                        ctx.dst->Div(
+                            ctx.dst->Sub(
+                                buffer_size_result->variable()->symbol(),
+                                array_offset),
+                            array_stride)));
+
+                // Insert the array length calculations at the top of the block
+                ctx.InsertBefore(block->statements(), *block->begin(),
+                                 buffer_size_result);
+                ctx.InsertBefore(block->statements(), *block->begin(),
+                                 call_get_dims);
+                ctx.InsertBefore(block->statements(), *block->begin(),
+                                 array_length_var);
+                return name;
+              });
+
+          // Replace the call to arrayLength() with the array length variable
+          ctx.Replace(call_expr, ctx.dst->Expr(array_length));
+        }
+      }
+    }
+  }
+
+  ctx.Clone();
+
+  return Output{Program(std::move(out))};
+}
+
+}  // namespace transform
+}  // namespace tint
diff --git a/src/transform/calculate_array_length.h b/src/transform/calculate_array_length.h
new file mode 100644
index 0000000..fa96d81
--- /dev/null
+++ b/src/transform/calculate_array_length.h
@@ -0,0 +1,68 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
+#define SRC_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
+
+#include <string>
+
+#include "src/ast/internal_decoration.h"
+#include "src/transform/transform.h"
+
+namespace tint {
+
+// Forward declarations
+class CloneContext;
+
+namespace transform {
+
+/// CalculateArrayLength is a transform used to replace calls to arrayLength()
+/// with a value calculated from the size of the storage buffer.
+class CalculateArrayLength : public Transform {
+ public:
+  /// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic
+  /// functions used to obtain the runtime size of a storage buffer.
+  class BufferSizeIntrinsic
+      : public Castable<BufferSizeIntrinsic, ast::InternalDecoration> {
+   public:
+    /// Constructor
+    BufferSizeIntrinsic();
+    /// Destructor
+    ~BufferSizeIntrinsic() override;
+
+    /// @return "buffer_size"
+    std::string Name() const override;
+
+    /// Performs a deep clone of this object using the CloneContext `ctx`.
+    /// @param ctx the clone context
+    /// @return the newly cloned object
+    BufferSizeIntrinsic* Clone(CloneContext* ctx) const override;
+  };
+
+  /// Constructor
+  CalculateArrayLength();
+  /// Destructor
+  ~CalculateArrayLength() override;
+
+  /// Runs the transform on `program`, returning the transformation result.
+  /// @param program the source program to transform
+  /// @param data optional extra transform-specific data
+  /// @returns the transformation result
+  Output Run(const Program* program, const DataMap& data = {}) override;
+};
+
+}  // namespace transform
+}  // namespace tint
+
+#endif  // SRC_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
diff --git a/src/transform/calculate_array_length_test.cc b/src/transform/calculate_array_length_test.cc
new file mode 100644
index 0000000..42b2f10
--- /dev/null
+++ b/src/transform/calculate_array_length_test.cc
@@ -0,0 +1,284 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/calculate_array_length.h"
+
+#include "src/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using CalculateArrayLengthTest = TransformTest;
+
+TEST_F(CalculateArrayLengthTest, Basic) {
+  auto* src = R"(
+[[block]]
+struct SB {
+  x : i32;
+  arr : array<i32>;
+};
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+  var len : u32 = arrayLength(sb.arr);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct SB {
+  x : i32;
+  arr : array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB, result : ptr<function, u32>)
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+  var tint_symbol_7 : u32 = 0u;
+  tint_symbol_1(sb, tint_symbol_7);
+  let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u);
+  var len : u32 = tint_symbol_9;
+}
+)";
+
+  auto got = Run<CalculateArrayLength>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, InSameBlock) {
+  auto* src = R"(
+[[block]]
+struct SB {
+  x : i32;
+  arr : array<i32>;
+};
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+  var a : u32 = arrayLength(sb.arr);
+  var b : u32 = arrayLength(sb.arr);
+  var c : u32 = arrayLength(sb.arr);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct SB {
+  x : i32;
+  arr : array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB, result : ptr<function, u32>)
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+  var tint_symbol_7 : u32 = 0u;
+  tint_symbol_1(sb, tint_symbol_7);
+  let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u);
+  var a : u32 = tint_symbol_9;
+  var b : u32 = tint_symbol_9;
+  var c : u32 = tint_symbol_9;
+}
+)";
+
+  auto got = Run<CalculateArrayLength>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, WithStride) {
+  auto* src = R"(
+[[block]]
+struct SB {
+  x : i32;
+  y : f32;
+  arr : [[stride(64)]] array<i32>;
+};
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+  var len : u32 = arrayLength(sb.arr);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct SB {
+  x : i32;
+  y : f32;
+  arr : [[stride(64)]] array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB, result : ptr<function, u32>)
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+  var tint_symbol_8 : u32 = 0u;
+  tint_symbol_1(sb, tint_symbol_8);
+  let tint_symbol_10 : u32 = ((tint_symbol_8 - 8u) / 64u);
+  var len : u32 = tint_symbol_10;
+}
+)";
+
+  auto got = Run<CalculateArrayLength>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, Nested) {
+  auto* src = R"(
+[[block]]
+struct SB {
+  x : i32;
+  arr : array<i32>;
+};
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+  if (true) {
+    var len : u32 = arrayLength(sb.arr);
+  } else {
+    if (true) {
+      var len : u32 = arrayLength(sb.arr);
+    }
+  }
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct SB {
+  x : i32;
+  arr : array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB, result : ptr<function, u32>)
+
+var<storage> sb : SB;
+
+[[stage(vertex)]]
+fn main() {
+  if (true) {
+    var tint_symbol_7 : u32 = 0u;
+    tint_symbol_1(sb, tint_symbol_7);
+    let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u);
+    var len : u32 = tint_symbol_9;
+  } else {
+    if (true) {
+      var tint_symbol_10 : u32 = 0u;
+      tint_symbol_1(sb, tint_symbol_10);
+      let tint_symbol_11 : u32 = ((tint_symbol_10 - 4u) / 4u);
+      var len : u32 = tint_symbol_11;
+    }
+  }
+}
+)";
+
+  auto got = Run<CalculateArrayLength>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CalculateArrayLengthTest, MultipleStorageBuffers) {
+  auto* src = R"(
+[[block]]
+struct SB1 {
+  x : i32;
+  arr1 : array<i32>;
+};
+
+[[block]]
+struct SB2 {
+  x : i32;
+  arr2 : array<vec4<f32>>;
+};
+
+var<storage> sb1 : SB1;
+
+var<storage> sb2 : SB2;
+
+[[stage(vertex)]]
+fn main() {
+  var len1 : u32 = arrayLength(sb1.arr1);
+  var len2 : u32 = arrayLength(sb2.arr2);
+  var x : u32 = (len1 + len2);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct SB1 {
+  x : i32;
+  arr1 : array<i32>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_1(buffer : SB1, result : ptr<function, u32>)
+
+[[block]]
+struct SB2 {
+  x : i32;
+  arr2 : array<vec4<f32>>;
+};
+
+[[internal(intrinsic_buffer_size)]]
+fn tint_symbol_10(buffer : SB2, result : ptr<function, u32>)
+
+var<storage> sb1 : SB1;
+
+var<storage> sb2 : SB2;
+
+[[stage(vertex)]]
+fn main() {
+  var tint_symbol_7 : u32 = 0u;
+  tint_symbol_1(sb1, tint_symbol_7);
+  let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u);
+  var tint_symbol_13 : u32 = 0u;
+  tint_symbol_10(sb2, tint_symbol_13);
+  let tint_symbol_15 : u32 = ((tint_symbol_13 - 16u) / 16u);
+  var len1 : u32 = tint_symbol_9;
+  var len2 : u32 = tint_symbol_15;
+  var x : u32 = (len1 + len2);
+}
+)";
+
+  auto got = Run<CalculateArrayLength>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+}  // namespace
+}  // namespace transform
+}  // namespace tint
diff --git a/test/BUILD.gn b/test/BUILD.gn
index ad3fe6e..6827f4b 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -194,6 +194,7 @@
     "../src/traits_test.cc",
     "../src/transform/binding_remapper_test.cc",
     "../src/transform/bound_array_accessors_test.cc",
+    "../src/transform/calculate_array_length_test.cc",
     "../src/transform/canonicalize_entry_point_io_test.cc",
     "../src/transform/decompose_storage_access_test.cc",
     "../src/transform/emit_vertex_point_size_test.cc",