Resolver: Validate that type sizes fit in uint32_t

Bug: chromium:1249708
Bug: tint:1177
Change-Id: I31c52f160e4952475e977453206ab4224fd20df7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64320
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 1b26b60..0b8c48d 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -17,6 +17,7 @@
 #include <algorithm>
 #include <cmath>
 #include <iomanip>
+#include <limits>
 #include <utility>
 
 #include "src/ast/alias.h"
@@ -3870,9 +3871,9 @@
   }
 
   // Calculate implicit stride
-  auto implicit_stride = utils::RoundUp(el_align, el_size);
+  uint64_t implicit_stride = utils::RoundUp<uint64_t>(el_align, el_size);
 
-  auto stride = explicit_stride ? explicit_stride : implicit_stride;
+  uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
 
   // Evaluate the constant array size expression.
   // sem::Array uses a size of 0 for a runtime-sized array.
@@ -3933,9 +3934,24 @@
     count = count_val.Elements()[0].u32;
   }
 
-  auto size = std::max<uint32_t>(count, 1) * stride;
-  auto* out = builder_->create<sem::Array>(elem_type, count, el_align, size,
-                                           stride, implicit_stride);
+  auto size = std::max<uint64_t>(count, 1) * stride;
+  if (size > std::numeric_limits<uint32_t>::max()) {
+    std::stringstream msg;
+    msg << "array size in bytes must not exceed 0x" << std::hex
+        << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
+        << size;
+    AddError(msg.str(), arr->source());
+    return nullptr;
+  }
+  if (stride > std::numeric_limits<uint32_t>::max() ||
+      implicit_stride > std::numeric_limits<uint32_t>::max()) {
+    TINT_ICE(Resolver, diagnostics_)
+        << "calculated array stride exceeds uint32";
+    return nullptr;
+  }
+  auto* out = builder_->create<sem::Array>(
+      elem_type, count, el_align, static_cast<uint32_t>(size),
+      static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
 
   if (!ValidateArray(out, source)) {
     return nullptr;
@@ -4154,8 +4170,8 @@
   // Validation of storage-class rules requires analysing the actual variable
   // usage of the structure, and so is performed as part of the variable
   // validation.
-  uint32_t struct_size = 0;
-  uint32_t struct_align = 1;
+  uint64_t struct_size = 0;
+  uint64_t struct_align = 1;
   std::unordered_map<Symbol, ast::StructMember*> member_map;
 
   for (auto* member : str->members()) {
@@ -4183,9 +4199,9 @@
       return nullptr;
     }
 
-    uint32_t offset = struct_size;
-    uint32_t align = type->Align();
-    uint32_t size = type->Size();
+    uint64_t offset = struct_size;
+    uint64_t align = type->Align();
+    uint64_t size = type->Size();
 
     if (!ValidateNoDuplicateDecorations(member->decorations())) {
       return nullptr;
@@ -4234,6 +4250,14 @@
     }
 
     offset = utils::RoundUp(align, offset);
+    if (offset > std::numeric_limits<uint32_t>::max()) {
+      std::stringstream msg;
+      msg << "struct member has byte offset 0x" << std::hex << offset
+          << ", but must not exceed 0x" << std::hex
+          << std::numeric_limits<uint32_t>::max();
+      AddError(msg.str(), member->source());
+      return nullptr;
+    }
 
     auto* sem_member = builder_->create<sem::StructMember>(
         member, member->symbol(), const_cast<sem::Type*>(type),
@@ -4245,12 +4269,27 @@
     struct_align = std::max(struct_align, align);
   }
 
-  auto size_no_padding = struct_size;
+  uint64_t size_no_padding = struct_size;
   struct_size = utils::RoundUp(struct_align, struct_size);
 
-  auto* out =
-      builder_->create<sem::Struct>(str, str->name(), sem_members, struct_align,
-                                    struct_size, size_no_padding);
+  if (struct_size > std::numeric_limits<uint32_t>::max()) {
+    std::stringstream msg;
+    msg << "struct size in bytes must not exceed 0x" << std::hex
+        << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
+        << struct_size;
+    AddError(msg.str(), str->source());
+    return nullptr;
+  }
+  if (struct_align > std::numeric_limits<uint32_t>::max()) {
+    TINT_ICE(Resolver, diagnostics_)
+        << "calculated struct stride exceeds uint32";
+    return nullptr;
+  }
+
+  auto* out = builder_->create<sem::Struct>(
+      str, str->name(), sem_members, static_cast<uint32_t>(struct_align),
+      static_cast<uint32_t>(struct_size),
+      static_cast<uint32_t>(size_no_padding));
 
   for (size_t i = 0; i < sem_members.size(); i++) {
     auto* mem_type = sem_members[i]->Type();
diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc
index 549b814..354978f 100644
--- a/src/resolver/type_validation_test.cc
+++ b/src/resolver/type_validation_test.cc
@@ -325,6 +325,26 @@
   EXPECT_EQ(r()->error(), "12:34 error: array size must be integer scalar");
 }
 
+TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) {
+  // var<private> a : array<f32, 0x40000000>;
+  Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x40000000),
+         ast::StorageClass::kPrivate);
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: array size in bytes must not exceed 0xffffffff, but "
+            "is 0x100000000");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) {
+  // var<private> a : [[stride(8)]] array<f32, 0x20000000>;
+  Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x20000000, 8),
+         ast::StorageClass::kPrivate);
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: array size in bytes must not exceed 0xffffffff, but "
+            "is 0x100000000");
+}
+
 TEST_F(ResolverTypeValidationTest, ArraySize_OverridableConstant) {
   // [[override]] let size = 10;
   // var<private> a : array<f32, size>;
@@ -396,8 +416,49 @@
             "a struct");
 }
 
+TEST_F(ResolverTypeValidationTest, Struct_TooBig) {
+  // struct Foo {
+  //   a: array<f32, 0x20000000>;
+  //   b: array<f32, 0x20000000>;
+  // };
+
+  Structure(Source{{12, 34}}, "Foo",
+            {
+                Member("a", ty.array<f32, 0x20000000>()),
+                Member("b", ty.array<f32, 0x20000000>()),
+            });
+
+  WrapInFunction();
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: struct size in bytes must not exceed 0xffffffff, but "
+            "is 0x100000000");
+}
+
+TEST_F(ResolverTypeValidationTest, Struct_MemberOffset_TooBig) {
+  // struct Foo {
+  //   a: array<f32, 0x3fffffff>;
+  //   b: f32;
+  //   c: f32;
+  // };
+
+  Structure("Foo", {
+                       Member("a", ty.array<f32, 0x3fffffff>()),
+                       Member("b", ty.f32()),
+                       Member(Source{{12, 34}}, "c", ty.f32()),
+                   });
+
+  WrapInFunction();
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: struct member has byte offset 0x100000000, but must "
+            "not exceed 0xffffffff");
+}
+
 TEST_F(ResolverTypeValidationTest, RuntimeArrayIsLast_Pass) {
-  // [[Block]]
+  // [[block]]
   // struct Foo {
   //   vf: f32;
   //   rt: array<f32>;
@@ -435,7 +496,7 @@
 }
 
 TEST_F(ResolverTypeValidationTest, RuntimeArrayIsNotLast_Fail) {
-  // [[Block]]
+  // [[block]]
   // struct Foo {
   //   rt: array<f32>;
   //   vf: f32;
@@ -504,7 +565,7 @@
 }
 
 TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsNotLast_Fail) {
-  // [[Block]]
+  // [[block]]
   // type RTArr = array<u32>;
   // struct s {
   //  b: RTArr;
@@ -528,7 +589,7 @@
 }
 
 TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
-  // [[Block]]
+  // [[block]]
   // type RTArr = array<u32>;
   // struct s {
   //  a: u32;
diff --git a/src/transform/robustness_test.cc b/src/transform/robustness_test.cc
index 2e5b96b..3f90787 100644
--- a/src/transform/robustness_test.cc
+++ b/src/transform/robustness_test.cc
@@ -170,7 +170,9 @@
   EXPECT_EQ(expect, str(got));
 }
 
-TEST_F(RobustnessTest, LargeArrays_Idx) {
+// TODO(crbug.com/tint/1177) - Validation currently forbids arrays larger than
+// 0xffffffff. If WGSL supports 64-bit indexing, re-enable this test.
+TEST_F(RobustnessTest, DISABLED_LargeArrays_Idx) {
   auto* src = R"(
 [[block]]
 struct S {