transform/Robustness: Re-work the accessor clamping

Rework the clamping so that it unifies the logic for arrays, matricies
and vectors. Try to preserve constant signess, and only clamp the values
if they're possibly out of bounds.

Use ConstantValue() instead of scanning for ScalarConstantExpressions.
As ConstantValue() improves, so will the performance of robustness.

Change-Id: I013a67e15f43350d0a57bcd8ba9ae0c1bcb1eaec
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58280
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/transform/robustness.cc b/src/transform/robustness.cc
index 436f0a4..92abd78 100644
--- a/src/transform/robustness.cc
+++ b/src/transform/robustness.cc
@@ -15,6 +15,7 @@
 #include "src/transform/robustness.h"
 
 #include <algorithm>
+#include <limits>
 #include <utility>
 
 #include "src/program_builder.h"
@@ -46,64 +47,140 @@
   /// cloned without changes.
   ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr) {
     auto* ret_type = ctx.src->Sem().Get(expr->array())->Type()->UnwrapRef();
-    if (!ret_type->IsAnyOf<sem::Array, sem::Matrix, sem::Vector>()) {
-      return nullptr;
-    }
 
     ProgramBuilder& b = *ctx.dst;
     using u32 = ProgramBuilder::u32;
 
-    uint32_t size = 0;
-    bool is_vec = ret_type->Is<sem::Vector>();
-    bool is_arr = ret_type->Is<sem::Array>();
-    if (is_vec || is_arr) {
-      size = is_vec ? ret_type->As<sem::Vector>()->size()
-                    : ret_type->As<sem::Array>()->Count();
-    } else {
+    struct Value {
+      ast::Expression* expr = nullptr;  // If null, then is a constant
+      union {
+        uint32_t u32 = 0;  // use if is_signed == false
+        int32_t i32;       // use if is_signed == true
+      };
+      bool is_signed = false;
+    };
+
+    Value size;              // size of the array, vector or matrix
+    size.is_signed = false;  // size is always unsigned
+    if (auto* vec = ret_type->As<sem::Vector>()) {
+      size.u32 = vec->size();
+
+    } else if (auto* arr = ret_type->As<sem::Array>()) {
+      size.u32 = arr->Count();
+    } else if (auto* mat = ret_type->As<sem::Matrix>()) {
       // The row accessor would have been an embedded array accessor and already
       // handled, so we just need to do columns here.
-      size = ret_type->As<sem::Matrix>()->columns();
+      size.u32 = mat->columns();
+    } else {
+      return nullptr;
     }
 
-    auto* const old_idx = expr->idx_expr();
-    b.SetSource(ctx.Clone(old_idx->source()));
-
-    ast::Expression* new_idx = nullptr;
-
-    if (size == 0) {
-      if (!is_arr) {
+    if (size.u32 == 0) {
+      if (!ret_type->Is<sem::Array>()) {
         b.Diagnostics().add_error(diag::System::Transform,
                                   "invalid 0 sized non-array", expr->source());
         return nullptr;
       }
       // Runtime sized array
       auto* arr = ctx.Clone(expr->array());
-      auto* arr_len = b.Call("arrayLength", b.AddressOf(arr));
-      auto* limit = b.Sub(arr_len, b.Expr(1u));
-      new_idx = b.Call("min", b.Construct<u32>(ctx.Clone(old_idx)), limit);
-    } else if (auto* c = old_idx->As<ast::ScalarConstructorExpression>()) {
-      // Scalar constructor we can re-write the value to be within bounds.
-      auto* lit = c->literal();
-      if (auto* sint = lit->As<ast::SintLiteral>()) {
-        int32_t max = static_cast<int32_t>(size) - 1;
-        new_idx = b.Expr(std::max(std::min(sint->value(), max), 0));
-      } else if (auto* uint = lit->As<ast::UintLiteral>()) {
-        new_idx = b.Expr(std::min(uint->value(), size - 1));
+      size.expr = b.Call("arrayLength", b.AddressOf(arr));
+    }
+
+    // Calculate the maximum possible index value (size-1u)
+    // Size must be positive (non-zero), so we can safely subtract 1 here
+    // without underflow.
+    Value limit;
+    limit.is_signed = false;  // Like size, limit is always unsigned.
+    if (size.expr) {
+      // Dynamic size
+      limit.expr = b.Sub(size.expr, 1u);
+    } else {
+      // Constant size
+      limit.u32 = size.u32 - 1u;
+    }
+
+    Value idx;  // index value
+
+    auto* idx_sem = ctx.src->Sem().Get(expr->idx_expr());
+    auto* idx_ty = idx_sem->Type()->UnwrapRef();
+    if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
+      TINT_ICE(Transform, b.Diagnostics())
+          << "index must be u32 or i32, got " << idx_sem->Type()->type_name();
+      return nullptr;
+    }
+
+    if (auto idx_constant = idx_sem->ConstantValue()) {
+      // Constant value index
+      if (idx_constant.Type()->Is<sem::I32>()) {
+        idx.i32 = idx_constant.Elements()[0].i32;
+        idx.is_signed = true;
+      } else if (idx_constant.Type()->Is<sem::U32>()) {
+        idx.u32 = idx_constant.Elements()[0].u32;
+        idx.is_signed = false;
       } else {
-        b.Diagnostics().add_error(
-            diag::System::Transform,
-            "unknown scalar constructor type for accessor", expr->source());
+        b.Diagnostics().add_error(diag::System::Transform,
+                                  "unsupported constant value for accessor: " +
+                                      idx_constant.Type()->type_name(),
+                                  expr->source());
         return nullptr;
       }
     } else {
-      auto* cloned_idx = ctx.Clone(old_idx);
-      new_idx = b.Call("min", b.Construct<u32>(cloned_idx), b.Expr(size - 1));
+      // Dynamic value index
+      idx.expr = ctx.Clone(expr->idx_expr());
+      idx.is_signed = idx_ty->Is<sem::I32>();
+    }
+
+    // Clamp the index so that it cannot exceed limit.
+    if (idx.expr || limit.expr) {
+      // One of, or both of idx and limit are non-constant.
+
+      // If the index is signed, cast it to a u32 (with clamping if constant).
+      if (idx.is_signed) {
+        if (idx.expr) {
+          // We don't use a max(idx, 0) here, as that incurs a runtime
+          // performance cost, and if the unsigned value will be clamped by
+          // limit, resulting in a value between [0..limit)
+          idx.expr = b.Construct<u32>(idx.expr);
+          idx.is_signed = false;
+        } else {
+          idx.u32 = static_cast<uint32_t>(std::max(idx.i32, 0));
+          idx.is_signed = false;
+        }
+      }
+
+      // Convert idx and limit to expressions, so we can emit `min(idx, limit)`.
+      if (!idx.expr) {
+        idx.expr = b.Expr(idx.u32);
+      }
+      if (!limit.expr) {
+        limit.expr = b.Expr(limit.u32);
+      }
+
+      // Perform the clamp with `min(idx, limit)`
+      idx.expr = b.Call("min", idx.expr, limit.expr);
+    } else {
+      // Both idx and max are constant.
+      if (idx.is_signed) {
+        // The index is signed. Calculate limit as signed.
+        int32_t signed_limit = static_cast<int32_t>(
+            std::min<uint32_t>(limit.u32, std::numeric_limits<int32_t>::max()));
+        idx.i32 = std::max(idx.i32, 0);
+        idx.i32 = std::min(idx.i32, signed_limit);
+      } else {
+        // The index is unsigned.
+        idx.u32 = std::min(idx.u32, limit.u32);
+      }
+    }
+
+    // Convert idx to an expression, so we can emit the new accessor.
+    if (!idx.expr) {
+      idx.expr = idx.is_signed ? b.Expr(idx.i32) : b.Expr(idx.u32);
     }
 
     // Clone arguments outside of create() call to have deterministic ordering
     auto src = ctx.Clone(expr->source());
     auto* arr = ctx.Clone(expr->array());
-    return b.IndexAccessor(src, arr, new_idx);
+    return b.IndexAccessor(src, arr, idx.expr);
   }
 
   /// @param type intrinsic type
diff --git a/src/transform/robustness_test.cc b/src/transform/robustness_test.cc
index f5e2f3f..2e1aab6 100644
--- a/src/transform/robustness_test.cc
+++ b/src/transform/robustness_test.cc
@@ -22,7 +22,7 @@
 
 using RobustnessTest = TransformTest;
 
-TEST_F(RobustnessTest, Ptrs_Clamp) {
+TEST_F(RobustnessTest, Array_Idx_Clamp) {
   auto* src = R"(
 var<private> a : array<f32, 3>;
 
@@ -39,7 +39,7 @@
 let c : u32 = 1u;
 
 fn f() {
-  let b : f32 = a[min(u32(c), 2u)];
+  let b : f32 = a[min(c, 2u)];
 }
 )";
 
@@ -69,7 +69,7 @@
 var<private> i : u32;
 
 fn f() {
-  var c : f32 = a[min(u32(b[min(u32(i), 4u)]), 2u)];
+  var c : f32 = a[min(u32(b[min(i, 4u)]), 2u)];
 }
 )";
 
@@ -170,6 +170,86 @@
   EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(RobustnessTest, LargeArrays_Idx) {
+  auto* src = R"(
+[[block]]
+struct S {
+  a : array<f32, 0x7fffffff>;
+  b : array<f32>;
+};
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+fn f() {
+  // Signed
+  var i32_a1 : f32 = s.a[ 0x7ffffffe];
+  var i32_a2 : f32 = s.a[ 1];
+  var i32_a3 : f32 = s.a[ 0];
+  var i32_a4 : f32 = s.a[-1];
+  var i32_a5 : f32 = s.a[-0x7fffffff];
+
+  var i32_b1 : f32 = s.b[ 0x7ffffffe];
+  var i32_b2 : f32 = s.b[ 1];
+  var i32_b3 : f32 = s.b[ 0];
+  var i32_b4 : f32 = s.b[-1];
+  var i32_b5 : f32 = s.b[-0x7fffffff];
+
+  // Unsigned
+  var u32_a1 : f32 = s.a[0u];
+  var u32_a2 : f32 = s.a[1u];
+  var u32_a3 : f32 = s.a[0x7ffffffeu];
+  var u32_a4 : f32 = s.a[0x7fffffffu];
+  var u32_a5 : f32 = s.a[0x80000000u];
+  var u32_a6 : f32 = s.a[0xffffffffu];
+
+  var u32_b1 : f32 = s.b[0u];
+  var u32_b2 : f32 = s.b[1u];
+  var u32_b3 : f32 = s.b[0x7ffffffeu];
+  var u32_b4 : f32 = s.b[0x7fffffffu];
+  var u32_b5 : f32 = s.b[0x80000000u];
+  var u32_b6 : f32 = s.b[0xffffffffu];
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct S {
+  a : array<f32, 2147483647>;
+  b : array<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+fn f() {
+  var i32_a1 : f32 = s.a[2147483646];
+  var i32_a2 : f32 = s.a[1];
+  var i32_a3 : f32 = s.a[0];
+  var i32_a4 : f32 = s.a[0];
+  var i32_a5 : f32 = s.a[0];
+  var i32_b1 : f32 = s.b[min(2147483646u, (arrayLength(&(s.b)) - 1u))];
+  var i32_b2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+  var i32_b3 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var i32_b4 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var i32_b5 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var u32_a1 : f32 = s.a[0u];
+  var u32_a2 : f32 = s.a[1u];
+  var u32_a3 : f32 = s.a[2147483646u];
+  var u32_a4 : f32 = s.a[2147483646u];
+  var u32_a5 : f32 = s.a[2147483646u];
+  var u32_a6 : f32 = s.a[2147483646u];
+  var u32_b1 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var u32_b2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+  var u32_b3 : f32 = s.b[min(2147483646u, (arrayLength(&(s.b)) - 1u))];
+  var u32_b4 : f32 = s.b[min(2147483647u, (arrayLength(&(s.b)) - 1u))];
+  var u32_b5 : f32 = s.b[min(2147483648u, (arrayLength(&(s.b)) - 1u))];
+  var u32_b6 : f32 = s.b[min(4294967295u, (arrayLength(&(s.b)) - 1u))];
+}
+)";
+
+  auto got = Run<Robustness>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
 TEST_F(RobustnessTest, Vector_Idx_Scalar) {
   auto* src = R"(
 var<private> a : vec3<f32>;
@@ -557,7 +637,7 @@
 [[group(0), binding(0)]] var<storage, read> s : S;
 
 fn f() {
-  var d : f32 = s.b[min(u32(25), (arrayLength(&(s.b)) - 1u))];
+  var d : f32 = s.b[min(25u, (arrayLength(&(s.b)) - 1u))];
 }
 )";
 
@@ -744,7 +824,7 @@
 let c : u32 = 1u;
 
 fn f() {
-  let b : f32 = s.b[min(u32(c), (arrayLength(&(s.b)) - 1u))];
+  let b : f32 = s.b[min(c, (arrayLength(&(s.b)) - 1u))];
   let x : i32 = min(1, 2);
   let y : u32 = arrayLength(&(s.b));
 }