Resolver: Validate that type sizes fit in uint32_t

Bug: chromium:1249708
Bug: tint:1177
Change-Id: I31c52f160e4952475e977453206ab4224fd20df7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64340
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 84d0963..d2948b6 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -17,6 +17,7 @@
 #include <algorithm>
 #include <cmath>
 #include <iomanip>
+#include <limits>
 #include <utility>
 
 #include "src/ast/alias.h"
@@ -3864,15 +3865,30 @@
   }
 
   // Calculate implicit stride
-  auto implicit_stride = utils::RoundUp(el_align, el_size);
+  uint64_t implicit_stride = utils::RoundUp<uint64_t>(el_align, el_size);
 
-  auto stride = explicit_stride ? explicit_stride : implicit_stride;
+  uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
 
   // WebGPU requires runtime arrays have at least one element, but the AST
   // records an element count of 0 for it.
   auto size = std::max<uint32_t>(arr->size(), 1) * stride;
-  auto* out = builder_->create<sem::Array>(elem_type, arr->size(), el_align,
-                                           size, stride, implicit_stride);
+  if (size > std::numeric_limits<uint32_t>::max()) {
+    std::stringstream msg;
+    msg << "array size in bytes must not exceed 0x" << std::hex
+        << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
+        << size;
+    AddError(msg.str(), arr->source());
+    return nullptr;
+  }
+  if (stride > std::numeric_limits<uint32_t>::max() ||
+      implicit_stride > std::numeric_limits<uint32_t>::max()) {
+    TINT_ICE(Resolver, diagnostics_)
+        << "calculated array stride exceeds uint32";
+    return nullptr;
+  }
+  auto* out = builder_->create<sem::Array>(
+      elem_type, arr->size(), el_align, static_cast<uint32_t>(size),
+      static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
 
   if (!ValidateArray(out, source)) {
     return nullptr;
@@ -4091,8 +4107,8 @@
   // Validation of storage-class rules requires analysing the actual variable
   // usage of the structure, and so is performed as part of the variable
   // validation.
-  uint32_t struct_size = 0;
-  uint32_t struct_align = 1;
+  uint64_t struct_size = 0;
+  uint64_t struct_align = 1;
   std::unordered_map<Symbol, ast::StructMember*> member_map;
 
   for (auto* member : str->members()) {
@@ -4120,9 +4136,9 @@
       return nullptr;
     }
 
-    uint32_t offset = struct_size;
-    uint32_t align = type->Align();
-    uint32_t size = type->Size();
+    uint64_t offset = struct_size;
+    uint64_t align = type->Align();
+    uint64_t size = type->Size();
 
     if (!ValidateNoDuplicateDecorations(member->decorations())) {
       return nullptr;
@@ -4171,6 +4187,14 @@
     }
 
     offset = utils::RoundUp(align, offset);
+    if (offset > std::numeric_limits<uint32_t>::max()) {
+      std::stringstream msg;
+      msg << "struct member has byte offset 0x" << std::hex << offset
+          << ", but must not exceed 0x" << std::hex
+          << std::numeric_limits<uint32_t>::max();
+      AddError(msg.str(), member->source());
+      return nullptr;
+    }
 
     auto* sem_member = builder_->create<sem::StructMember>(
         member, member->symbol(), const_cast<sem::Type*>(type),
@@ -4182,12 +4206,27 @@
     struct_align = std::max(struct_align, align);
   }
 
-  auto size_no_padding = struct_size;
+  uint64_t size_no_padding = struct_size;
   struct_size = utils::RoundUp(struct_align, struct_size);
 
-  auto* out =
-      builder_->create<sem::Struct>(str, str->name(), sem_members, struct_align,
-                                    struct_size, size_no_padding);
+  if (struct_size > std::numeric_limits<uint32_t>::max()) {
+    std::stringstream msg;
+    msg << "struct size in bytes must not exceed 0x" << std::hex
+        << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
+        << struct_size;
+    AddError(msg.str(), str->source());
+    return nullptr;
+  }
+  if (struct_align > std::numeric_limits<uint32_t>::max()) {
+    TINT_ICE(Resolver, diagnostics_)
+        << "calculated struct stride exceeds uint32";
+    return nullptr;
+  }
+
+  auto* out = builder_->create<sem::Struct>(
+      str, str->name(), sem_members, static_cast<uint32_t>(struct_align),
+      static_cast<uint32_t>(struct_size),
+      static_cast<uint32_t>(size_no_padding));
 
   for (size_t i = 0; i < sem_members.size(); i++) {
     auto* mem_type = sem_members[i]->Type();
diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc
index 52925b4..5a30264 100644
--- a/src/resolver/type_validation_test.cc
+++ b/src/resolver/type_validation_test.cc
@@ -201,6 +201,26 @@
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
+TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) {
+  // var<private> a : array<f32, 0x40000000>;
+  Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x40000000),
+         ast::StorageClass::kPrivate);
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: array size in bytes must not exceed 0xffffffff, but "
+            "is 0x100000000");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) {
+  // var<private> a : [[stride(8)]] array<f32, 0x20000000>;
+  Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x20000000, 8),
+         ast::StorageClass::kPrivate);
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: array size in bytes must not exceed 0xffffffff, but "
+            "is 0x100000000");
+}
+
 TEST_F(ResolverTypeValidationTest, RuntimeArrayInFunction_Fail) {
   /// [[stage(vertex)]]
   // fn func() { var a : array<i32>; }
@@ -222,8 +242,49 @@
             "a struct");
 }
 
+TEST_F(ResolverTypeValidationTest, Struct_TooBig) {
+  // struct Foo {
+  //   a: array<f32, 0x20000000>;
+  //   b: array<f32, 0x20000000>;
+  // };
+
+  Structure(Source{{12, 34}}, "Foo",
+            {
+                Member("a", ty.array<f32, 0x20000000>()),
+                Member("b", ty.array<f32, 0x20000000>()),
+            });
+
+  WrapInFunction();
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: struct size in bytes must not exceed 0xffffffff, but "
+            "is 0x100000000");
+}
+
+TEST_F(ResolverTypeValidationTest, Struct_MemberOffset_TooBig) {
+  // struct Foo {
+  //   a: array<f32, 0x3fffffff>;
+  //   b: f32;
+  //   c: f32;
+  // };
+
+  Structure("Foo", {
+                       Member("a", ty.array<f32, 0x3fffffff>()),
+                       Member("b", ty.f32()),
+                       Member(Source{{12, 34}}, "c", ty.f32()),
+                   });
+
+  WrapInFunction();
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: struct member has byte offset 0x100000000, but must "
+            "not exceed 0xffffffff");
+}
+
 TEST_F(ResolverTypeValidationTest, RuntimeArrayIsLast_Pass) {
-  // [[Block]]
+  // [[block]]
   // struct Foo {
   //   vf: f32;
   //   rt: array<f32>;
@@ -261,7 +322,7 @@
 }
 
 TEST_F(ResolverTypeValidationTest, RuntimeArrayIsNotLast_Fail) {
-  // [[Block]]
+  // [[block]]
   // struct Foo {
   //   rt: array<f32>;
   //   vf: f32;
@@ -330,7 +391,7 @@
 }
 
 TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsNotLast_Fail) {
-  // [[Block]]
+  // [[block]]
   // type RTArr = array<u32>;
   // struct s {
   //  b: RTArr;
@@ -354,7 +415,7 @@
 }
 
 TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
-  // [[Block]]
+  // [[block]]
   // type RTArr = array<u32>;
   // struct s {
   //  a: u32;
diff --git a/src/transform/robustness_test.cc b/src/transform/robustness_test.cc
index 2e5b96b..3f90787 100644
--- a/src/transform/robustness_test.cc
+++ b/src/transform/robustness_test.cc
@@ -170,7 +170,9 @@
   EXPECT_EQ(expect, str(got));
 }
 
-TEST_F(RobustnessTest, LargeArrays_Idx) {
+// TODO(crbug.com/tint/1177) - Validation currently forbids arrays larger than
+// 0xffffffff. If WGSL supports 64-bit indexing, re-enable this test.
+TEST_F(RobustnessTest, DISABLED_LargeArrays_Idx) {
   auto* src = R"(
 [[block]]
 struct S {