Validate array stride

Fix: tint:707
Change-Id: I4439b10f173d8753bd1d407629954a7dad61a679
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47826
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc
index 70d1b4c..0c021a2 100644
--- a/src/resolver/decoration_validation_test.cc
+++ b/src/resolver/decoration_validation_test.cc
@@ -24,6 +24,9 @@
 #include "gmock/gmock.h"
 
 namespace tint {
+namespace resolver {
+
+namespace DecorationTests {
 namespace {
 
 enum class DecorationKind {
@@ -45,12 +48,11 @@
   DecorationKind kind;
   bool should_pass;
 };
-class TestWithParams : public resolver::TestHelper,
-                       public testing::TestWithParam<TestParams> {};
+struct TestWithParams : ResolverTestWithParam<TestParams> {};
 
-ast::Decoration* createDecoration(const Source& source,
-                                  ProgramBuilder& builder,
-                                  DecorationKind kind) {
+static ast::Decoration* createDecoration(const Source& source,
+                                         ProgramBuilder& builder,
+                                         DecorationKind kind) {
   switch (kind) {
     case DecorationKind::kAccess:
       return builder.create<ast::AccessDecoration>(
@@ -87,7 +89,7 @@
 
 using FunctionReturnTypeDecorationTest = TestWithParams;
 TEST_P(FunctionReturnTypeDecorationTest, IsValid) {
-  auto params = GetParam();
+  auto& params = GetParam();
 
   Func("main", ast::VariableList{}, ty.f32(),
        ast::StatementList{create<ast::ReturnStatement>(Expr(1.f))},
@@ -121,9 +123,8 @@
                     TestParams{DecorationKind::kWorkgroup, false}));
 
 using ArrayDecorationTest = TestWithParams;
-
 TEST_P(ArrayDecorationTest, IsValid) {
-  auto params = GetParam();
+  auto& params = GetParam();
 
   ast::StructMemberList members{Member(
       "a", create<type::Array>(ty.f32(), 0,
@@ -163,7 +164,7 @@
 
 using StructDecorationTest = TestWithParams;
 TEST_P(StructDecorationTest, IsValid) {
-  auto params = GetParam();
+  auto& params = GetParam();
 
   auto* s = create<ast::Struct>(ast::StructMemberList{},
                                 ast::DecorationList{createDecoration(
@@ -200,7 +201,7 @@
 
 using StructMemberDecorationTest = TestWithParams;
 TEST_P(StructMemberDecorationTest, IsValid) {
-  auto params = GetParam();
+  auto& params = GetParam();
 
   ast::StructMemberList members{
       Member("a", ty.i32(),
@@ -239,7 +240,7 @@
 
 using VariableDecorationTest = TestWithParams;
 TEST_P(VariableDecorationTest, IsValid) {
-  auto params = GetParam();
+  auto& params = GetParam();
 
   Global("a", ty.f32(), ast::StorageClass::kInput, nullptr,
          ast::DecorationList{
@@ -274,7 +275,7 @@
 
 using FunctionDecorationTest = TestWithParams;
 TEST_P(FunctionDecorationTest, IsValid) {
-  auto params = GetParam();
+  auto& params = GetParam();
 
   Func("foo", ast::VariableList{}, ty.void_(), ast::StatementList{},
        ast::DecorationList{
@@ -307,4 +308,121 @@
                     TestParams{DecorationKind::kWorkgroup, true}));
 
 }  // namespace
+}  // namespace DecorationTests
+
+namespace ArrayStrideTests {
+namespace {
+
+struct Params {
+  create_type_func_ptr create_el_type;
+  uint32_t stride;
+  bool should_pass;
+};
+
+struct TestWithParams : ResolverTestWithParam<Params> {};
+
+using ArrayStrideTest = TestWithParams;
+TEST_P(ArrayStrideTest, All) {
+  auto& params = GetParam();
+  auto* el_ty = params.create_el_type(ty);
+
+  std::stringstream ss;
+  ss << "el_ty: " << el_ty->FriendlyName(Symbols())
+     << ", stride: " << params.stride
+     << ", should_pass: " << params.should_pass;
+  SCOPED_TRACE(ss.str());
+
+  auto* arr =
+      create<type::Array>(el_ty, 4,
+                          ast::DecorationList{
+                              create<ast::StrideDecoration>(params.stride),
+                          });
+
+  Global(Source{{12, 34}}, "myarray", arr, ast::StorageClass::kInput);
+
+  if (params.should_pass) {
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+  } else {
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(),
+              "12:34 error: arrays decorated with the stride attribute must "
+              "have a stride that is at least the size of the element type, "
+              "and be a multiple of the element type's alignment value.");
+  }
+}
+
+// Helpers and typedefs
+using i32 = ProgramBuilder::i32;
+using u32 = ProgramBuilder::u32;
+using f32 = ProgramBuilder::f32;
+
+struct SizeAndAlignment {
+  uint32_t size;
+  uint32_t align;
+};
+constexpr SizeAndAlignment default_u32 = {4, 4};
+constexpr SizeAndAlignment default_i32 = {4, 4};
+constexpr SizeAndAlignment default_f32 = {4, 4};
+constexpr SizeAndAlignment default_vec2 = {8, 8};
+constexpr SizeAndAlignment default_vec3 = {12, 16};
+constexpr SizeAndAlignment default_vec4 = {16, 16};
+constexpr SizeAndAlignment default_mat2x2 = {16, 8};
+constexpr SizeAndAlignment default_mat3x3 = {48, 16};
+constexpr SizeAndAlignment default_mat4x4 = {64, 16};
+
+INSTANTIATE_TEST_SUITE_P(
+    ResolverDecorationValidationTest,
+    ArrayStrideTest,
+    testing::Values(
+        // Succeed because stride >= element size (while being multiple of
+        // element alignment)
+        Params{ty_u32, default_u32.size, true},
+        Params{ty_i32, default_i32.size, true},
+        Params{ty_f32, default_f32.size, true},
+        Params{ty_vec2<f32>, default_vec2.size, true},
+        // vec3's default size is not a multiple of its alignment
+        // Params{ty_vec3<f32>, default_vec3.size, true},
+        Params{ty_vec4<f32>, default_vec4.size, true},
+        Params{ty_mat2x2<f32>, default_mat2x2.size, true},
+        Params{ty_mat3x3<f32>, default_mat3x3.size, true},
+        Params{ty_mat4x4<f32>, default_mat4x4.size, true},
+
+        // Fail because stride is < element size
+        Params{ty_u32, default_u32.size - 1, false},
+        Params{ty_i32, default_i32.size - 1, false},
+        Params{ty_f32, default_f32.size - 1, false},
+        Params{ty_vec2<f32>, default_vec2.size - 1, false},
+        Params{ty_vec3<f32>, default_vec3.size - 1, false},
+        Params{ty_vec4<f32>, default_vec4.size - 1, false},
+        Params{ty_mat2x2<f32>, default_mat2x2.size - 1, false},
+        Params{ty_mat3x3<f32>, default_mat3x3.size - 1, false},
+        Params{ty_mat4x4<f32>, default_mat4x4.size - 1, false},
+
+        // Succeed because stride equals multiple of element alignment
+        Params{ty_u32, default_u32.align * 7, true},
+        Params{ty_i32, default_i32.align * 7, true},
+        Params{ty_f32, default_f32.align * 7, true},
+        Params{ty_vec2<f32>, default_vec2.align * 7, true},
+        Params{ty_vec3<f32>, default_vec3.align * 7, true},
+        Params{ty_vec4<f32>, default_vec4.align * 7, true},
+        Params{ty_mat2x2<f32>, default_mat2x2.align * 7, true},
+        Params{ty_mat3x3<f32>, default_mat3x3.align * 7, true},
+        Params{ty_mat4x4<f32>, default_mat4x4.align * 7, true},
+
+        // Fail because stride is not multiple of element alignment
+        Params{ty_u32, (default_u32.align - 1) * 7, false},
+        Params{ty_i32, (default_i32.align - 1) * 7, false},
+        Params{ty_f32, (default_f32.align - 1) * 7, false},
+        Params{ty_vec2<f32>, (default_vec2.align - 1) * 7, false},
+        Params{ty_vec3<f32>, (default_vec3.align - 1) * 7, false},
+        Params{ty_vec4<f32>, (default_vec4.align - 1) * 7, false},
+        Params{ty_mat2x2<f32>, (default_mat2x2.align - 1) * 7, false},
+        Params{ty_mat3x3<f32>, (default_mat3x3.align - 1) * 7, false},
+        Params{ty_mat4x4<f32>, (default_mat4x4.align - 1) * 7, false}
+
+        ));
+
+}  // namespace
+}  // namespace ArrayStrideTests
+}  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index dfcabea..5bbc820 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1824,13 +1824,13 @@
     return nullptr;
   }
 
-  auto create_semantic = [&](uint32_t stride) -> semantic::Array* {
-    uint32_t el_align = 0;
-    uint32_t el_size = 0;
-    if (!DefaultAlignAndSize(arr->type(), el_align, el_size, source)) {
-      return nullptr;
-    }
+  uint32_t el_align = 0;
+  uint32_t el_size = 0;
+  if (!DefaultAlignAndSize(el_ty, el_align, el_size, source)) {
+    return nullptr;
+  }
 
+  auto create_semantic = [&](uint32_t stride) -> semantic::Array* {
     auto align = el_align;
     // WebGPU requires runtime arrays have at least one element, but the AST
     // records an element count of 0 for it.
@@ -1843,18 +1843,30 @@
   // Look for explicit stride via [[stride(n)]] decoration
   for (auto* deco : arr->decorations()) {
     if (auto* stride = deco->As<ast::StrideDecoration>()) {
-      return create_semantic(stride->stride());
+      auto explicit_stride = stride->stride();
+      bool is_valid_stride = (explicit_stride >= el_size) &&
+                             (explicit_stride >= el_align) &&
+                             (explicit_stride % el_align == 0);
+      if (!is_valid_stride) {
+        // https://gpuweb.github.io/gpuweb/wgsl/#array-layout-rules
+        // Arrays decorated with the stride attribute must have a stride that is
+        // at least the size of the element type, and be a multiple of the
+        // element type's alignment value.
+        diagnostics_.add_error(
+            "arrays decorated with the stride attribute must have a stride "
+            "that is at least the size of the element type, and be a multiple "
+            "of the element type's alignment value.",
+            source);
+        return nullptr;
+      }
+
+      return create_semantic(explicit_stride);
     }
   }
 
   // Calculate implicit stride
-  uint32_t el_align = 0;
-  uint32_t el_size = 0;
-  if (!DefaultAlignAndSize(el_ty, el_align, el_size, source)) {
-    return nullptr;
-  }
-
-  return create_semantic(utils::RoundUp(el_align, el_size));
+  auto implicit_stride = utils::RoundUp(el_align, el_size);
+  return create_semantic(implicit_stride);
 }
 
 bool Resolver::ValidateStructure(const type::Struct* st) {
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index cd9dd7a..a3dfce1 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -123,6 +123,17 @@
     type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
 
 template <typename T>
+type::Type* ty_vec2(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.vec2<T>();
+}
+
+template <create_type_func_ptr create_type>
+type::Type* ty_vec2(const ProgramBuilder::TypesBuilder& ty) {
+  auto* type = create_type(ty);
+  return ty.vec2(type);
+}
+
+template <typename T>
 type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
   return ty.vec3<T>();
 }
@@ -134,6 +145,28 @@
 }
 
 template <typename T>
+type::Type* ty_vec4(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.vec4<T>();
+}
+
+template <create_type_func_ptr create_type>
+type::Type* ty_vec4(const ProgramBuilder::TypesBuilder& ty) {
+  auto* type = create_type(ty);
+  return ty.vec4(type);
+}
+
+template <typename T>
+type::Type* ty_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.mat2x2<T>();
+}
+
+template <create_type_func_ptr create_type>
+type::Type* ty_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
+  auto* type = create_type(ty);
+  return ty.mat2x2(type);
+}
+
+template <typename T>
 type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
   return ty.mat3x3<T>();
 }
@@ -144,6 +177,17 @@
   return ty.mat3x3(type);
 }
 
+template <typename T>
+type::Type* ty_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.mat4x4<T>();
+}
+
+template <create_type_func_ptr create_type>
+type::Type* ty_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
+  auto* type = create_type(ty);
+  return ty.mat4x4(type);
+}
+
 template <create_type_func_ptr create_type>
 type::Type* ty_alias(const ProgramBuilder::TypesBuilder& ty) {
   auto* type = create_type(ty);