Add a knob to omit certain storage classes from the Robustness transform
BUG=tint:779
Change-Id: Ibcedb998671dd2bf189cc795299ea92846196ade
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/66780
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/transform/robustness.cc b/src/transform/robustness.cc
index 9aa5a3e..6bd7678 100644
--- a/src/transform/robustness.cc
+++ b/src/transform/robustness.cc
@@ -22,9 +22,11 @@
#include "src/sem/block_statement.h"
#include "src/sem/call.h"
#include "src/sem/expression.h"
+#include "src/sem/reference_type.h"
#include "src/sem/statement.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness);
+TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness::Config);
namespace tint {
namespace transform {
@@ -34,6 +36,9 @@
/// The clone context
CloneContext& ctx;
+ /// Set of storage classes to not apply the transform to
+ std::unordered_set<ast::StorageClass> omitted_classes;
+
/// Applies the transformation state to `ctx`.
void Transform() {
ctx.ReplaceAll(
@@ -46,7 +51,14 @@
/// @return the clamped replacement expression, or nullptr if `expr` should be
/// cloned without changes.
ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr) {
- auto* ret_type = ctx.src->Sem().Get(expr->array)->Type()->UnwrapRef();
+ auto* ret_type = ctx.src->Sem().Get(expr->array)->Type();
+
+ auto* ref = ret_type->As<sem::Reference>();
+ if (ref && omitted_classes.count(ref->StorageClass()) != 0) {
+ return nullptr;
+ }
+
+ auto* ret_unwrapped = ret_type->UnwrapRef();
ProgramBuilder& b = *ctx.dst;
using u32 = ProgramBuilder::u32;
@@ -62,12 +74,12 @@
Value size; // size of the array, vector or matrix
size.is_signed = false; // size is always unsigned
- if (auto* vec = ret_type->As<sem::Vector>()) {
+ if (auto* vec = ret_unwrapped->As<sem::Vector>()) {
size.u32 = vec->Width();
- } else if (auto* arr = ret_type->As<sem::Array>()) {
+ } else if (auto* arr = ret_unwrapped->As<sem::Array>()) {
size.u32 = arr->Count();
- } else if (auto* mat = ret_type->As<sem::Matrix>()) {
+ } else if (auto* mat = ret_unwrapped->As<sem::Matrix>()) {
// The row accessor would have been an embedded array accessor and already
// handled, so we just need to do columns here.
size.u32 = mat->columns();
@@ -76,7 +88,7 @@
}
if (size.u32 == 0) {
- if (!ret_type->Is<sem::Array>()) {
+ if (!ret_unwrapped->Is<sem::Array>()) {
b.Diagnostics().add_error(diag::System::Transform,
"invalid 0 sized non-array", expr->source);
return nullptr;
@@ -268,11 +280,34 @@
}
};
+Robustness::Config::Config() = default;
+Robustness::Config::Config(const Config&) = default;
+Robustness::Config::~Config() = default;
+Robustness::Config& Robustness::Config::operator=(const Config&) = default;
+
Robustness::Robustness() = default;
Robustness::~Robustness() = default;
-void Robustness::Run(CloneContext& ctx, const DataMap&, DataMap&) {
- State state{ctx};
+void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
+ // Translate the public Robustness::StorageClass values requested by the
+ // optional Config input into the ast::StorageClass values that State checks
+ // against. When no Config is supplied, nothing is omitted.
+ std::unordered_set<ast::StorageClass> omitted_classes;
+ if (auto* cfg = inputs.Get<Config>()) {
+ for (auto sc : cfg->omitted_classes) {
+ switch (sc) {
+ case StorageClass::kStorage:
+ omitted_classes.insert(ast::StorageClass::kStorage);
+ break;
+ case StorageClass::kUniform:
+ omitted_classes.insert(ast::StorageClass::kUniform);
+ break;
+ }
+ }
+ }
+
+ State state{ctx, std::move(omitted_classes)};
+
state.Transform();
ctx.Clone();
}
diff --git a/src/transform/robustness.h b/src/transform/robustness.h
index fcade3a..1333e5c 100644
--- a/src/transform/robustness.h
+++ b/src/transform/robustness.h
@@ -15,6 +15,8 @@
#ifndef SRC_TRANSFORM_ROBUSTNESS_H_
#define SRC_TRANSFORM_ROBUSTNESS_H_
+#include <unordered_set>
+
#include "src/transform/transform.h"
// Forward declarations
@@ -34,6 +36,32 @@
/// (array length - 1).
class Robustness : public Castable<Robustness, Transform> {
public:
+ /// Storage classes that the transform can be configured to skip.
+ /// Robustness::Run() maps each value onto the ast::StorageClass of the
+ /// same name.
+ enum class StorageClass {
+ /// The `uniform` storage class
+ kUniform,
+ /// The `storage` storage class
+ kStorage,
+ };
+
+ /// Configuration options for the transform, supplied through the input
+ /// DataMap. If no Config is provided, no storage classes are omitted.
+ struct Config : public Castable<Config, Data> {
+ /// Constructor
+ Config();
+
+ /// Copy constructor
+ Config(const Config&);
+
+ /// Destructor
+ ~Config() override;
+
+ /// Assignment operator
+ /// @returns this Config
+ Config& operator=(const Config&);
+
+ /// Storage classes that the transform is not applied to. Array, vector and
+ /// matrix index accesses through references in these storage classes are
+ /// left unclamped.
+ /// This allows for optimizing on hardware that provides safe accesses.
+ std::unordered_set<StorageClass> omitted_classes;
+ };
+
/// Constructor
Robustness();
/// Destructor
diff --git a/src/transform/robustness_test.cc b/src/transform/robustness_test.cc
index a6612be..a1cf043 100644
--- a/src/transform/robustness_test.cc
+++ b/src/transform/robustness_test.cc
@@ -818,6 +818,331 @@
EXPECT_EQ(expect, str(got));
}
+// Input shader shared by the Omit* tests that follow. It declares a read-only
+// storage buffer `s` (fixed-size array `a`, runtime-sized array `b`) and a
+// uniform buffer `u` (fixed-size array `a`). It indexes each array with signed
+// and unsigned constant indices that are in bounds, out of bounds or negative.
+// Declared `const char* const` so the shared fixture cannot be reassigned.
+const char* const kOmitSourceShader = R"(
+[[block]]
+struct S {
+ a : array<f32, 4>;
+ b : array<f32>;
+};
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+[[block]] struct U {
+ a : UArr;
+};
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+ // Signed
+ var i32_sa1 : f32 = s.a[4];
+ var i32_sa2 : f32 = s.a[1];
+ var i32_sa3 : f32 = s.a[0];
+ var i32_sa4 : f32 = s.a[-1];
+ var i32_sa5 : f32 = s.a[-4];
+
+ var i32_sb1 : f32 = s.b[4];
+ var i32_sb2 : f32 = s.b[1];
+ var i32_sb3 : f32 = s.b[0];
+ var i32_sb4 : f32 = s.b[-1];
+ var i32_sb5 : f32 = s.b[-4];
+
+ var i32_ua1 : f32 = u.a[4];
+ var i32_ua2 : f32 = u.a[1];
+ var i32_ua3 : f32 = u.a[0];
+ var i32_ua4 : f32 = u.a[-1];
+ var i32_ua5 : f32 = u.a[-4];
+
+ // Unsigned
+ var u32_sa1 : f32 = s.a[0u];
+ var u32_sa2 : f32 = s.a[1u];
+ var u32_sa3 : f32 = s.a[3u];
+ var u32_sa4 : f32 = s.a[4u];
+ var u32_sa5 : f32 = s.a[10u];
+ var u32_sa6 : f32 = s.a[100u];
+
+ var u32_sb1 : f32 = s.b[0u];
+ var u32_sb2 : f32 = s.b[1u];
+ var u32_sb3 : f32 = s.b[3u];
+ var u32_sb4 : f32 = s.b[4u];
+ var u32_sb5 : f32 = s.b[10u];
+ var u32_sb6 : f32 = s.b[100u];
+
+ var u32_ua1 : f32 = u.a[0u];
+ var u32_ua2 : f32 = u.a[1u];
+ var u32_ua3 : f32 = u.a[3u];
+ var u32_ua4 : f32 = u.a[4u];
+ var u32_ua5 : f32 = u.a[10u];
+ var u32_ua6 : f32 = u.a[100u];
+}
+)";
+
+TEST_F(RobustnessTest, OmitNone) {
+ // A default-constructed Config omits no storage classes, so accesses to
+ // both the storage buffer `s` and the uniform buffer `u` are clamped.
+ auto* expect = R"(
+[[block]]
+struct S {
+ a : array<f32, 4>;
+ b : array<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+
+[[block]]
+struct U {
+ a : UArr;
+};
+
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+ var i32_sa1 : f32 = s.a[3];
+ var i32_sa2 : f32 = s.a[1];
+ var i32_sa3 : f32 = s.a[0];
+ var i32_sa4 : f32 = s.a[0];
+ var i32_sa5 : f32 = s.a[0];
+ var i32_sb1 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
+ var i32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+ var i32_sb3 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+ var i32_sb4 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+ var i32_sb5 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+ var i32_ua1 : f32 = u.a[3];
+ var i32_ua2 : f32 = u.a[1];
+ var i32_ua3 : f32 = u.a[0];
+ var i32_ua4 : f32 = u.a[0];
+ var i32_ua5 : f32 = u.a[0];
+ var u32_sa1 : f32 = s.a[0u];
+ var u32_sa2 : f32 = s.a[1u];
+ var u32_sa3 : f32 = s.a[3u];
+ var u32_sa4 : f32 = s.a[3u];
+ var u32_sa5 : f32 = s.a[3u];
+ var u32_sa6 : f32 = s.a[3u];
+ var u32_sb1 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb3 : f32 = s.b[min(3u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb4 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb5 : f32 = s.b[min(10u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb6 : f32 = s.b[min(100u, (arrayLength(&(s.b)) - 1u))];
+ var u32_ua1 : f32 = u.a[0u];
+ var u32_ua2 : f32 = u.a[1u];
+ var u32_ua3 : f32 = u.a[3u];
+ var u32_ua4 : f32 = u.a[3u];
+ var u32_ua5 : f32 = u.a[3u];
+ var u32_ua6 : f32 = u.a[3u];
+}
+)";
+
+ DataMap data;
+ Robustness::Config cfg;  // nothing omitted
+ data.Add<Robustness::Config>(cfg);
+
+ auto result = Run<Robustness>(kOmitSourceShader, data);
+ EXPECT_EQ(expect, str(result));
+}
+
+TEST_F(RobustnessTest, OmitStorage) {
+ // Only the storage class is omitted: accesses to the storage buffer `s` are
+ // left untouched, while accesses to the uniform buffer `u` are still clamped.
+ auto* expect = R"(
+[[block]]
+struct S {
+ a : array<f32, 4>;
+ b : array<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+
+[[block]]
+struct U {
+ a : UArr;
+};
+
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+ var i32_sa1 : f32 = s.a[4];
+ var i32_sa2 : f32 = s.a[1];
+ var i32_sa3 : f32 = s.a[0];
+ var i32_sa4 : f32 = s.a[-1];
+ var i32_sa5 : f32 = s.a[-4];
+ var i32_sb1 : f32 = s.b[4];
+ var i32_sb2 : f32 = s.b[1];
+ var i32_sb3 : f32 = s.b[0];
+ var i32_sb4 : f32 = s.b[-1];
+ var i32_sb5 : f32 = s.b[-4];
+ var i32_ua1 : f32 = u.a[3];
+ var i32_ua2 : f32 = u.a[1];
+ var i32_ua3 : f32 = u.a[0];
+ var i32_ua4 : f32 = u.a[0];
+ var i32_ua5 : f32 = u.a[0];
+ var u32_sa1 : f32 = s.a[0u];
+ var u32_sa2 : f32 = s.a[1u];
+ var u32_sa3 : f32 = s.a[3u];
+ var u32_sa4 : f32 = s.a[4u];
+ var u32_sa5 : f32 = s.a[10u];
+ var u32_sa6 : f32 = s.a[100u];
+ var u32_sb1 : f32 = s.b[0u];
+ var u32_sb2 : f32 = s.b[1u];
+ var u32_sb3 : f32 = s.b[3u];
+ var u32_sb4 : f32 = s.b[4u];
+ var u32_sb5 : f32 = s.b[10u];
+ var u32_sb6 : f32 = s.b[100u];
+ var u32_ua1 : f32 = u.a[0u];
+ var u32_ua2 : f32 = u.a[1u];
+ var u32_ua3 : f32 = u.a[3u];
+ var u32_ua4 : f32 = u.a[3u];
+ var u32_ua5 : f32 = u.a[3u];
+ var u32_ua6 : f32 = u.a[3u];
+}
+)";
+
+ Robustness::Config cfg;
+ cfg.omitted_classes = {Robustness::StorageClass::kStorage};
+
+ DataMap data;
+ data.Add<Robustness::Config>(cfg);
+
+ auto result = Run<Robustness>(kOmitSourceShader, data);
+ EXPECT_EQ(expect, str(result));
+}
+
+TEST_F(RobustnessTest, OmitUniform) {
+ // Only the uniform storage class is omitted: accesses to the uniform buffer
+ // `u` are left untouched, while accesses to the storage buffer `s` are still
+ // clamped.
+ auto* expect = R"(
+[[block]]
+struct S {
+ a : array<f32, 4>;
+ b : array<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+
+[[block]]
+struct U {
+ a : UArr;
+};
+
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+ var i32_sa1 : f32 = s.a[3];
+ var i32_sa2 : f32 = s.a[1];
+ var i32_sa3 : f32 = s.a[0];
+ var i32_sa4 : f32 = s.a[0];
+ var i32_sa5 : f32 = s.a[0];
+ var i32_sb1 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
+ var i32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+ var i32_sb3 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+ var i32_sb4 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+ var i32_sb5 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+ var i32_ua1 : f32 = u.a[4];
+ var i32_ua2 : f32 = u.a[1];
+ var i32_ua3 : f32 = u.a[0];
+ var i32_ua4 : f32 = u.a[-1];
+ var i32_ua5 : f32 = u.a[-4];
+ var u32_sa1 : f32 = s.a[0u];
+ var u32_sa2 : f32 = s.a[1u];
+ var u32_sa3 : f32 = s.a[3u];
+ var u32_sa4 : f32 = s.a[3u];
+ var u32_sa5 : f32 = s.a[3u];
+ var u32_sa6 : f32 = s.a[3u];
+ var u32_sb1 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb3 : f32 = s.b[min(3u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb4 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb5 : f32 = s.b[min(10u, (arrayLength(&(s.b)) - 1u))];
+ var u32_sb6 : f32 = s.b[min(100u, (arrayLength(&(s.b)) - 1u))];
+ var u32_ua1 : f32 = u.a[0u];
+ var u32_ua2 : f32 = u.a[1u];
+ var u32_ua3 : f32 = u.a[3u];
+ var u32_ua4 : f32 = u.a[4u];
+ var u32_ua5 : f32 = u.a[10u];
+ var u32_ua6 : f32 = u.a[100u];
+}
+)";
+
+ Robustness::Config cfg;
+ cfg.omitted_classes = {Robustness::StorageClass::kUniform};
+
+ DataMap data;
+ data.Add<Robustness::Config>(cfg);
+
+ auto result = Run<Robustness>(kOmitSourceShader, data);
+ EXPECT_EQ(expect, str(result));
+}
+
+TEST_F(RobustnessTest, OmitBoth) {
+ // Both storage classes are omitted, so no access in the shader is modified:
+ // the output matches the input apart from formatting and dropped comments.
+ auto* expect = R"(
+[[block]]
+struct S {
+ a : array<f32, 4>;
+ b : array<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read> s : S;
+
+type UArr = [[stride(16)]] array<f32, 4>;
+
+[[block]]
+struct U {
+ a : UArr;
+};
+
+[[group(1), binding(0)]] var<uniform> u : U;
+
+fn f() {
+ var i32_sa1 : f32 = s.a[4];
+ var i32_sa2 : f32 = s.a[1];
+ var i32_sa3 : f32 = s.a[0];
+ var i32_sa4 : f32 = s.a[-1];
+ var i32_sa5 : f32 = s.a[-4];
+ var i32_sb1 : f32 = s.b[4];
+ var i32_sb2 : f32 = s.b[1];
+ var i32_sb3 : f32 = s.b[0];
+ var i32_sb4 : f32 = s.b[-1];
+ var i32_sb5 : f32 = s.b[-4];
+ var i32_ua1 : f32 = u.a[4];
+ var i32_ua2 : f32 = u.a[1];
+ var i32_ua3 : f32 = u.a[0];
+ var i32_ua4 : f32 = u.a[-1];
+ var i32_ua5 : f32 = u.a[-4];
+ var u32_sa1 : f32 = s.a[0u];
+ var u32_sa2 : f32 = s.a[1u];
+ var u32_sa3 : f32 = s.a[3u];
+ var u32_sa4 : f32 = s.a[4u];
+ var u32_sa5 : f32 = s.a[10u];
+ var u32_sa6 : f32 = s.a[100u];
+ var u32_sb1 : f32 = s.b[0u];
+ var u32_sb2 : f32 = s.b[1u];
+ var u32_sb3 : f32 = s.b[3u];
+ var u32_sb4 : f32 = s.b[4u];
+ var u32_sb5 : f32 = s.b[10u];
+ var u32_sb6 : f32 = s.b[100u];
+ var u32_ua1 : f32 = u.a[0u];
+ var u32_ua2 : f32 = u.a[1u];
+ var u32_ua3 : f32 = u.a[3u];
+ var u32_ua4 : f32 = u.a[4u];
+ var u32_ua5 : f32 = u.a[10u];
+ var u32_ua6 : f32 = u.a[100u];
+}
+)";
+
+ Robustness::Config cfg;
+ cfg.omitted_classes = {Robustness::StorageClass::kStorage,
+ Robustness::StorageClass::kUniform};
+
+ DataMap data;
+ data.Add<Robustness::Config>(cfg);
+
+ auto result = Run<Robustness>(kOmitSourceShader, data);
+ EXPECT_EQ(expect, str(result));
+}
+
} // namespace
} // namespace transform
} // namespace tint