Add knob for omitting certain storage classes in Robustness transform

BUG=tint:779

Change-Id: Ibcedb998671dd2bf189cc795299ea92846196ade
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/66780
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/transform/robustness.cc b/src/transform/robustness.cc
index 9aa5a3e..6bd7678 100644
--- a/src/transform/robustness.cc
+++ b/src/transform/robustness.cc
@@ -22,9 +22,11 @@
 #include "src/sem/block_statement.h"
 #include "src/sem/call.h"
 #include "src/sem/expression.h"
+#include "src/sem/reference_type.h"
 #include "src/sem/statement.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness);
+TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness::Config);
 
 namespace tint {
 namespace transform {
@@ -34,6 +36,9 @@
   /// The clone context
   CloneContext& ctx;
 
+  /// Set of storage classes to not apply the transform to
+  std::unordered_set<ast::StorageClass> omitted_classes;
+
   /// Applies the transformation state to `ctx`.
   void Transform() {
     ctx.ReplaceAll(
@@ -46,7 +51,14 @@
   /// @return the clamped replacement expression, or nullptr if `expr` should be
   /// cloned without changes.
   ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr) {
-    auto* ret_type = ctx.src->Sem().Get(expr->array)->Type()->UnwrapRef();
+    auto* ret_type = ctx.src->Sem().Get(expr->array)->Type();
+
+    auto* ref = ret_type->As<sem::Reference>();
+    if (ref && omitted_classes.count(ref->StorageClass()) != 0) {
+      return nullptr;
+    }
+
+    auto* ret_unwrapped = ret_type->UnwrapRef();
 
     ProgramBuilder& b = *ctx.dst;
     using u32 = ProgramBuilder::u32;
@@ -62,12 +74,12 @@
 
     Value size;              // size of the array, vector or matrix
     size.is_signed = false;  // size is always unsigned
-    if (auto* vec = ret_type->As<sem::Vector>()) {
+    if (auto* vec = ret_unwrapped->As<sem::Vector>()) {
       size.u32 = vec->Width();
 
-    } else if (auto* arr = ret_type->As<sem::Array>()) {
+    } else if (auto* arr = ret_unwrapped->As<sem::Array>()) {
       size.u32 = arr->Count();
-    } else if (auto* mat = ret_type->As<sem::Matrix>()) {
+    } else if (auto* mat = ret_unwrapped->As<sem::Matrix>()) {
       // The row accessor would have been an embedded array accessor and already
       // handled, so we just need to do columns here.
       size.u32 = mat->columns();
@@ -76,7 +88,7 @@
     }
 
     if (size.u32 == 0) {
-      if (!ret_type->Is<sem::Array>()) {
+      if (!ret_unwrapped->Is<sem::Array>()) {
         b.Diagnostics().add_error(diag::System::Transform,
                                   "invalid 0 sized non-array", expr->source);
         return nullptr;
@@ -268,11 +280,34 @@
   }
 };
 
+Robustness::Config::Config() = default;  // Omits no storage classes.
+Robustness::Config::Config(const Config&) = default;
+Robustness::Config::~Config() = default;
+Robustness::Config& Robustness::Config::operator=(const Config&) = default;
+
 Robustness::Robustness() = default;
 Robustness::~Robustness() = default;
 
-void Robustness::Run(CloneContext& ctx, const DataMap&, DataMap&) {
-  State state{ctx};
+void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
+  // An absent Config behaves like an empty one: nothing is omitted.
+  std::unordered_set<ast::StorageClass> skipped;
+  if (const auto* cfg = inputs.Get<Config>()) {
+    // Translate the public enum into the matching AST storage classes.
+    for (auto cls : cfg->omitted_classes) {
+      switch (cls) {
+        case StorageClass::kStorage:
+          skipped.emplace(ast::StorageClass::kStorage);
+          break;
+        case StorageClass::kUniform:
+          skipped.emplace(ast::StorageClass::kUniform);
+          break;
+      }
+    }
+  }
+
+  // Array accesses into any class in `skipped` are cloned unchanged.
+  State state{ctx, std::move(skipped)};
+
   state.Transform();
   ctx.Clone();
 }
diff --git a/src/transform/robustness.h b/src/transform/robustness.h
index fcade3a..1333e5c 100644
--- a/src/transform/robustness.h
+++ b/src/transform/robustness.h
@@ -15,6 +15,8 @@
 #ifndef SRC_TRANSFORM_ROBUSTNESS_H_
 #define SRC_TRANSFORM_ROBUSTNESS_H_
 
+#include <unordered_set>
+
 #include "src/transform/transform.h"
 
 // Forward declarations
@@ -34,6 +36,32 @@
 /// (array length - 1).
 class Robustness : public Castable<Robustness, Transform> {
  public:
+  /// Storage classes that can be omitted from the transform via Config
+  enum class StorageClass {
+    kUniform,  ///< Corresponds to ast::StorageClass::kUniform
+    kStorage,  ///< Corresponds to ast::StorageClass::kStorage
+  };
+
+  /// Optional configuration for the transform, read from the input DataMap
+  struct Config : public Castable<Config, Data> {
+    /// Constructor
+    Config();
+
+    /// Copy constructor
+    Config(const Config&);
+
+    /// Destructor
+    ~Config() override;
+
+    /// Assignment operator
+    /// @returns this Config
+    Config& operator=(const Config&);
+
+    /// Storage classes whose array accesses the transform leaves unclamped.
+    /// This allows for optimizing on hardware that provides safe accesses.
+    std::unordered_set<StorageClass> omitted_classes;
+  };
+
   /// Constructor
   Robustness();
   /// Destructor
diff --git a/src/transform/robustness_test.cc b/src/transform/robustness_test.cc
index a6612be..a1cf043 100644
--- a/src/transform/robustness_test.cc
+++ b/src/transform/robustness_test.cc
@@ -818,6 +818,331 @@
   EXPECT_EQ(expect, str(got));
 }
 
+const char* const kOmitSourceShader = R"(
+[[block]]
+struct S {
+  a : array<f32, 4>;
+  b : array<f32>;
+};
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+[[block]] struct U {
+  a : UArr;
+};
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+  // Signed
+  var i32_sa1 : f32 = s.a[4];
+  var i32_sa2 : f32 = s.a[1];
+  var i32_sa3 : f32 = s.a[0];
+  var i32_sa4 : f32 = s.a[-1];
+  var i32_sa5 : f32 = s.a[-4];
+
+  var i32_sb1 : f32 = s.b[4];
+  var i32_sb2 : f32 = s.b[1];
+  var i32_sb3 : f32 = s.b[0];
+  var i32_sb4 : f32 = s.b[-1];
+  var i32_sb5 : f32 = s.b[-4];
+
+  var i32_ua1 : f32 = u.a[4];
+  var i32_ua2 : f32 = u.a[1];
+  var i32_ua3 : f32 = u.a[0];
+  var i32_ua4 : f32 = u.a[-1];
+  var i32_ua5 : f32 = u.a[-4];
+
+  // Unsigned
+  var u32_sa1 : f32 = s.a[0u];
+  var u32_sa2 : f32 = s.a[1u];
+  var u32_sa3 : f32 = s.a[3u];
+  var u32_sa4 : f32 = s.a[4u];
+  var u32_sa5 : f32 = s.a[10u];
+  var u32_sa6 : f32 = s.a[100u];
+
+  var u32_sb1 : f32 = s.b[0u];
+  var u32_sb2 : f32 = s.b[1u];
+  var u32_sb3 : f32 = s.b[3u];
+  var u32_sb4 : f32 = s.b[4u];
+  var u32_sb5 : f32 = s.b[10u];
+  var u32_sb6 : f32 = s.b[100u];
+
+  var u32_ua1 : f32 = u.a[0u];
+  var u32_ua2 : f32 = u.a[1u];
+  var u32_ua3 : f32 = u.a[3u];
+  var u32_ua4 : f32 = u.a[4u];
+  var u32_ua5 : f32 = u.a[10u];
+  var u32_ua6 : f32 = u.a[100u];
+}
+)";  // Shared storage/uniform input for the Omit* tests below.
+
+TEST_F(RobustnessTest, OmitNone) {
+  auto* expect = R"(
+[[block]]
+struct S {
+  a : array<f32, 4>;
+  b : array<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+
+[[block]]
+struct U {
+  a : UArr;
+};
+
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+  var i32_sa1 : f32 = s.a[3];
+  var i32_sa2 : f32 = s.a[1];
+  var i32_sa3 : f32 = s.a[0];
+  var i32_sa4 : f32 = s.a[0];
+  var i32_sa5 : f32 = s.a[0];
+  var i32_sb1 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
+  var i32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+  var i32_sb3 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var i32_sb4 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var i32_sb5 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var i32_ua1 : f32 = u.a[3];
+  var i32_ua2 : f32 = u.a[1];
+  var i32_ua3 : f32 = u.a[0];
+  var i32_ua4 : f32 = u.a[0];
+  var i32_ua5 : f32 = u.a[0];
+  var u32_sa1 : f32 = s.a[0u];
+  var u32_sa2 : f32 = s.a[1u];
+  var u32_sa3 : f32 = s.a[3u];
+  var u32_sa4 : f32 = s.a[3u];
+  var u32_sa5 : f32 = s.a[3u];
+  var u32_sa6 : f32 = s.a[3u];
+  var u32_sb1 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb3 : f32 = s.b[min(3u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb4 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb5 : f32 = s.b[min(10u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb6 : f32 = s.b[min(100u, (arrayLength(&(s.b)) - 1u))];
+  var u32_ua1 : f32 = u.a[0u];
+  var u32_ua2 : f32 = u.a[1u];
+  var u32_ua3 : f32 = u.a[3u];
+  var u32_ua4 : f32 = u.a[3u];
+  var u32_ua5 : f32 = u.a[3u];
+  var u32_ua6 : f32 = u.a[3u];
+}
+)";
+  // An empty Config omits nothing: s.* and u.* accesses are all clamped.
+  Robustness::Config cfg;
+  DataMap data;
+  data.Add<Robustness::Config>(cfg);
+
+  auto got = Run<Robustness>(kOmitSourceShader, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RobustnessTest, OmitStorage) {
+  auto* expect = R"(
+[[block]]
+struct S {
+  a : array<f32, 4>;
+  b : array<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+
+[[block]]
+struct U {
+  a : UArr;
+};
+
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+  var i32_sa1 : f32 = s.a[4];
+  var i32_sa2 : f32 = s.a[1];
+  var i32_sa3 : f32 = s.a[0];
+  var i32_sa4 : f32 = s.a[-1];
+  var i32_sa5 : f32 = s.a[-4];
+  var i32_sb1 : f32 = s.b[4];
+  var i32_sb2 : f32 = s.b[1];
+  var i32_sb3 : f32 = s.b[0];
+  var i32_sb4 : f32 = s.b[-1];
+  var i32_sb5 : f32 = s.b[-4];
+  var i32_ua1 : f32 = u.a[3];
+  var i32_ua2 : f32 = u.a[1];
+  var i32_ua3 : f32 = u.a[0];
+  var i32_ua4 : f32 = u.a[0];
+  var i32_ua5 : f32 = u.a[0];
+  var u32_sa1 : f32 = s.a[0u];
+  var u32_sa2 : f32 = s.a[1u];
+  var u32_sa3 : f32 = s.a[3u];
+  var u32_sa4 : f32 = s.a[4u];
+  var u32_sa5 : f32 = s.a[10u];
+  var u32_sa6 : f32 = s.a[100u];
+  var u32_sb1 : f32 = s.b[0u];
+  var u32_sb2 : f32 = s.b[1u];
+  var u32_sb3 : f32 = s.b[3u];
+  var u32_sb4 : f32 = s.b[4u];
+  var u32_sb5 : f32 = s.b[10u];
+  var u32_sb6 : f32 = s.b[100u];
+  var u32_ua1 : f32 = u.a[0u];
+  var u32_ua2 : f32 = u.a[1u];
+  var u32_ua3 : f32 = u.a[3u];
+  var u32_ua4 : f32 = u.a[3u];
+  var u32_ua5 : f32 = u.a[3u];
+  var u32_ua6 : f32 = u.a[3u];
+}
+)";
+  // Omitting kStorage leaves s.a / s.b untouched; u.a is still clamped.
+  Robustness::Config cfg;
+  cfg.omitted_classes.insert(Robustness::StorageClass::kStorage);
+
+  DataMap data;
+  data.Add<Robustness::Config>(cfg);
+
+  auto got = Run<Robustness>(kOmitSourceShader, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RobustnessTest, OmitUniform) {
+  auto* expect = R"(
+[[block]]
+struct S {
+  a : array<f32, 4>;
+  b : array<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+
+[[block]]
+struct U {
+  a : UArr;
+};
+
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+  var i32_sa1 : f32 = s.a[3];
+  var i32_sa2 : f32 = s.a[1];
+  var i32_sa3 : f32 = s.a[0];
+  var i32_sa4 : f32 = s.a[0];
+  var i32_sa5 : f32 = s.a[0];
+  var i32_sb1 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
+  var i32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+  var i32_sb3 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var i32_sb4 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var i32_sb5 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var i32_ua1 : f32 = u.a[4];
+  var i32_ua2 : f32 = u.a[1];
+  var i32_ua3 : f32 = u.a[0];
+  var i32_ua4 : f32 = u.a[-1];
+  var i32_ua5 : f32 = u.a[-4];
+  var u32_sa1 : f32 = s.a[0u];
+  var u32_sa2 : f32 = s.a[1u];
+  var u32_sa3 : f32 = s.a[3u];
+  var u32_sa4 : f32 = s.a[3u];
+  var u32_sa5 : f32 = s.a[3u];
+  var u32_sa6 : f32 = s.a[3u];
+  var u32_sb1 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb3 : f32 = s.b[min(3u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb4 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb5 : f32 = s.b[min(10u, (arrayLength(&(s.b)) - 1u))];
+  var u32_sb6 : f32 = s.b[min(100u, (arrayLength(&(s.b)) - 1u))];
+  var u32_ua1 : f32 = u.a[0u];
+  var u32_ua2 : f32 = u.a[1u];
+  var u32_ua3 : f32 = u.a[3u];
+  var u32_ua4 : f32 = u.a[4u];
+  var u32_ua5 : f32 = u.a[10u];
+  var u32_ua6 : f32 = u.a[100u];
+}
+)";
+  // Omitting kUniform leaves u.a untouched; s.a / s.b are still clamped.
+  Robustness::Config cfg;
+  cfg.omitted_classes.insert(Robustness::StorageClass::kUniform);
+
+  DataMap data;
+  data.Add<Robustness::Config>(cfg);
+
+  auto got = Run<Robustness>(kOmitSourceShader, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(RobustnessTest, OmitBoth) {
+  auto* expect = R"(
+[[block]]
+struct S {
+  a : array<f32, 4>;
+  b : array<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+
+[[block]]
+struct U {
+  a : UArr;
+};
+
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+  var i32_sa1 : f32 = s.a[4];
+  var i32_sa2 : f32 = s.a[1];
+  var i32_sa3 : f32 = s.a[0];
+  var i32_sa4 : f32 = s.a[-1];
+  var i32_sa5 : f32 = s.a[-4];
+  var i32_sb1 : f32 = s.b[4];
+  var i32_sb2 : f32 = s.b[1];
+  var i32_sb3 : f32 = s.b[0];
+  var i32_sb4 : f32 = s.b[-1];
+  var i32_sb5 : f32 = s.b[-4];
+  var i32_ua1 : f32 = u.a[4];
+  var i32_ua2 : f32 = u.a[1];
+  var i32_ua3 : f32 = u.a[0];
+  var i32_ua4 : f32 = u.a[-1];
+  var i32_ua5 : f32 = u.a[-4];
+  var u32_sa1 : f32 = s.a[0u];
+  var u32_sa2 : f32 = s.a[1u];
+  var u32_sa3 : f32 = s.a[3u];
+  var u32_sa4 : f32 = s.a[4u];
+  var u32_sa5 : f32 = s.a[10u];
+  var u32_sa6 : f32 = s.a[100u];
+  var u32_sb1 : f32 = s.b[0u];
+  var u32_sb2 : f32 = s.b[1u];
+  var u32_sb3 : f32 = s.b[3u];
+  var u32_sb4 : f32 = s.b[4u];
+  var u32_sb5 : f32 = s.b[10u];
+  var u32_sb6 : f32 = s.b[100u];
+  var u32_ua1 : f32 = u.a[0u];
+  var u32_ua2 : f32 = u.a[1u];
+  var u32_ua3 : f32 = u.a[3u];
+  var u32_ua4 : f32 = u.a[4u];
+  var u32_ua5 : f32 = u.a[10u];
+  var u32_ua6 : f32 = u.a[100u];
+}
+)";
+  // Omitting both classes leaves every access in the shader untouched.
+  Robustness::Config cfg;
+  cfg.omitted_classes.insert(Robustness::StorageClass::kStorage);
+  cfg.omitted_classes.insert(Robustness::StorageClass::kUniform);
+
+  DataMap data;
+  data.Add<Robustness::Config>(cfg);
+
+  auto got = Run<Robustness>(kOmitSourceShader, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
 }  // namespace
 }  // namespace transform
 }  // namespace tint