Implement clamping of runtime array accesses

Bug: tint:252
Change-Id: I2b32ab9d69ca39b6178fc4e94ccd090516a37c98
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/36620
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc
index 4543a17..c7ce85e 100644
--- a/src/transform/bound_array_accessors.cc
+++ b/src/transform/bound_array_accessors.cc
@@ -14,6 +14,7 @@
 
 #include "src/transform/bound_array_accessors.h"
 
+#include <algorithm>
 #include <memory>
 #include <utility>
 
@@ -22,6 +23,7 @@
 #include "src/ast/bitcast_expression.h"
 #include "src/ast/block_statement.h"
 #include "src/ast/break_statement.h"
+#include "src/ast/builder.h"
 #include "src/ast/call_expression.h"
 #include "src/ast/call_statement.h"
 #include "src/ast/case_statement.h"
@@ -74,47 +76,47 @@
     return nullptr;
   }
 
-  uint32_t size = 0;
-  if (ret_type->Is<ast::type::Vector>() || ret_type->Is<ast::type::Array>()) {
-    size = ret_type->Is<ast::type::Vector>()
-               ? ret_type->As<ast::type::Vector>()->size()
-               : ret_type->As<ast::type::Array>()->size();
-    if (size == 0) {
-      diag::Diagnostic err;
-      err.severity = diag::Severity::Error;
-      err.message = "invalid 0 size for array or vector";
-      err.source = expr->source();
-      diags->add(std::move(err));
-      return nullptr;
-    }
+  ast::Builder b(ctx->mod);
+  using u32 = ast::Builder::u32;
 
+  uint32_t size = 0;
+  bool is_vec = ret_type->Is<ast::type::Vector>();
+  bool is_arr = ret_type->Is<ast::type::Array>();
+  if (is_vec || is_arr) {
+    size = is_vec ? ret_type->As<ast::type::Vector>()->size()
+                  : ret_type->As<ast::type::Array>()->size();
   } else {
     // The row accessor would have been an embedded array accessor and already
     // handled, so we just need to do columns here.
     size = ret_type->As<ast::type::Matrix>()->columns();
   }
 
-  ast::Expression* idx_expr = nullptr;
+  auto* const old_idx = expr->idx_expr();
+  b.SetSource(ctx->Clone(old_idx->source()));
 
-  // Scalar constructor we can re-write the value to be within bounds.
-  if (auto* c = expr->idx_expr()->As<ast::ScalarConstructorExpression>()) {
+  ast::Expression* new_idx = nullptr;
+
+  if (size == 0) {
+    if (is_arr) {
+      auto* arr_len = b.Call("arrayLength", ctx->Clone(expr->array()));
+      auto* limit = b.Sub(arr_len, b.Expr(1u));
+      new_idx = b.Call("min", b.Construct<u32>(ctx->Clone(old_idx)), limit);
+    } else {
+      diag::Diagnostic err;
+      err.severity = diag::Severity::Error;
+      err.message = "invalid 0 size";
+      err.source = expr->source();
+      diags->add(std::move(err));
+      return nullptr;
+    }
+  } else if (auto* c = old_idx->As<ast::ScalarConstructorExpression>()) {
+    // Scalar constructor we can re-write the value to be within bounds.
     auto* lit = c->literal();
     if (auto* sint = lit->As<ast::SintLiteral>()) {
-      int32_t val = sint->value();
-      if (val < 0) {
-        val = 0;
-      } else if (val >= int32_t(size)) {
-        val = int32_t(size) - 1;
-      }
-      lit = ctx->mod->create<ast::SintLiteral>(ctx->Clone(sint->source()),
-                                               ctx->Clone(sint->type()), val);
+      int32_t max = static_cast<int32_t>(size) - 1;
+      new_idx = b.Expr(std::max(std::min(sint->value(), max), 0));
     } else if (auto* uint = lit->As<ast::UintLiteral>()) {
-      uint32_t val = uint->value();
-      if (val >= size - 1) {
-        val = size - 1;
-      }
-      lit = ctx->mod->create<ast::UintLiteral>(ctx->Clone(uint->source()),
-                                               ctx->Clone(uint->type()), val);
+      new_idx = b.Expr(std::min(uint->value(), size - 1));
     } else {
       diag::Diagnostic err;
       err.severity = diag::Severity::Error;
@@ -123,32 +125,13 @@
       diags->add(std::move(err));
       return nullptr;
     }
-    idx_expr =
-        ctx->mod->create<ast::ScalarConstructorExpression>(c->source(), lit);
   } else {
-    auto* u32 = ctx->mod->create<ast::type::U32>();
-
-    ast::ExpressionList cast_expr;
-    cast_expr.push_back(ctx->Clone(expr->idx_expr()));
-
-    ast::ExpressionList params;
-    params.push_back(ctx->mod->create<ast::TypeConstructorExpression>(
-        Source{}, u32, cast_expr));
-    params.push_back(ctx->mod->create<ast::ScalarConstructorExpression>(
-        Source{}, ctx->mod->create<ast::UintLiteral>(Source{}, u32, size - 1)));
-
-    auto* call_expr = ctx->mod->create<ast::CallExpression>(
-        Source{},
-        ctx->mod->create<ast::IdentifierExpression>(
-            Source{}, ctx->mod->RegisterSymbol("min"), "min"),
-        std::move(params));
-    call_expr->set_result_type(u32);
-
-    idx_expr = call_expr;
+    new_idx =
+        b.Call("min", b.Construct<u32>(ctx->Clone(old_idx)), b.Expr(size - 1));
   }
 
-  return ctx->mod->create<ast::ArrayAccessorExpression>(
-      ctx->Clone(expr->source()), ctx->Clone(expr->array()), idx_expr);
+  return b.create<ast::ArrayAccessorExpression>(
+      ctx->Clone(expr->source()), ctx->Clone(expr->array()), new_idx);
 }
 
 }  // namespace transform
diff --git a/src/transform/bound_array_accessors_test.cc b/src/transform/bound_array_accessors_test.cc
index e7d70ce..14ce062 100644
--- a/src/transform/bound_array_accessors_test.cc
+++ b/src/transform/bound_array_accessors_test.cc
@@ -446,16 +446,35 @@
   // -> var b : f32 = a[1][min(u32(idx), 0, 1)]
 }
 
-// TODO(dsinclair): Implement when we have arrayLength for Runtime Arrays
-TEST_F(BoundArrayAccessorsTest, DISABLED_RuntimeArray_Clamps) {
-  // struct S {
-  //   a : f32;
-  //   b : array<f32>;
-  // }
-  // S s;
-  // var b : f32 = s.b[25]
-  //
-  // -> var b : f32 = s.b[min(u32(25), arrayLength(s.b))]
+TEST_F(BoundArrayAccessorsTest, RuntimeArray_Clamps) {
+  auto* src = R"(
+struct S {
+  a : f32;
+  b : array<f32>;
+};
+var s : S;
+
+fn f() -> void {
+  var d : f32 = s.b[25];
+}
+)";
+
+  auto* expect = R"(
+struct S {
+  a : f32;
+  b : array<f32>;
+};
+
+var s : S;
+
+fn f() -> void {
+  var d : f32 = s.b[min(u32(25), (arrayLength(s.b) - 1u))];
+}
+)";
+
+  auto got = Transform<BoundArrayAccessors>(src);
+
+  EXPECT_EQ(expect, got);
 }
 
 // TODO(dsinclair): Clamp atomics when available.