sem: Fold together sem::Array and sem::ArrayType

There's now no need to have both.
Removes a whole bunch of Sem().Get() smell, and simplifies the resolver.

Also fixes a long-standing issue where an array with an explicit stride attribute whose value equals the implicit stride would resolve to a different type from an otherwise identical array without the decoration.

Bug: tint:724
Fixed: tint:782
Change-Id: I0202459009cd45be427cdb621993a5a3b07ff51e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50301
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 32273e2..a3162c2 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -178,8 +178,8 @@
   if (type->is_scalar() || type->Is<sem::Vector>() || type->Is<sem::Matrix>()) {
     return true;
   }
-  if (auto* arr = type->As<sem::ArrayType>()) {
-    return IsStorable(arr->type());
+  if (auto* arr = type->As<sem::Array>()) {
+    return IsStorable(arr->ElemType());
   }
   if (auto* str = type->As<sem::Struct>()) {
     for (const auto* member : str->Members()) {
@@ -204,8 +204,8 @@
   if (auto* mat = type->As<sem::Matrix>()) {
     return IsHostShareable(mat->type());
   }
-  if (auto* arr = type->As<sem::ArrayType>()) {
-    return IsHostShareable(arr->type());
+  if (auto* arr = type->As<sem::Array>()) {
+    return IsHostShareable(arr->ElemType());
   }
   if (auto* str = type->As<sem::Struct>()) {
     for (auto* member : str->Members()) {
@@ -287,7 +287,7 @@
         // TODO(crbug.com/tint/724) - Remove once tint:724 is complete.
         // ast::AccessDecorations are generated by the WGSL parser, used to
         // build sem::AccessControls and then leaked.
-        // ast::StrideDecoration are used to build a sem::ArrayTypes, but
+        // ast::StrideDecoration are used to build a sem::Arrays, but
         // multiple arrays of the same stride, size and element type are
         // currently de-duplicated by the type manager, and we leak these
         // decorations.
@@ -350,14 +350,7 @@
       return nullptr;
     }
     if (auto* t = ty->As<ast::Array>()) {
-      if (auto* el = Type(t->type())) {
-        auto* sem = builder_->create<sem::ArrayType>(
-            const_cast<sem::Type*>(el), t->size(), t->decorations());
-        if (Array(sem, ty->source())) {
-          return sem;
-        }
-      }
-      return nullptr;
+      return Array(t);
     }
     if (auto* t = ty->As<ast::Pointer>()) {
       if (auto* el = Type(t->type())) {
@@ -420,9 +413,10 @@
   return s;
 }
 
-Resolver::VariableInfo* Resolver::Variable(ast::Variable* var,
-                                           sem::Type* type, /* = nullptr */
-                                           std::string type_name /* = "" */) {
+Resolver::VariableInfo* Resolver::Variable(
+    ast::Variable* var,
+    const sem::Type* type, /* = nullptr */
+    std::string type_name /* = "" */) {
   auto it = variable_to_info_.find(var);
   if (it != variable_to_info_.end()) {
     return it->second;
@@ -436,18 +430,10 @@
     return nullptr;
   }
 
-  auto* ctype = Canonical(type);
+  auto* ctype = Canonical(const_cast<sem::Type*>(type));
   auto* info = variable_infos_.Create(var, ctype, type_name);
   variable_to_info_.emplace(var, info);
 
-  // TODO(bclayton): Why is this here? Needed?
-  // Resolve variable's type
-  if (auto* arr = info->type->As<sem::ArrayType>()) {
-    if (!Array(arr, var->source())) {
-      return nullptr;
-    }
-  }
-
   return info;
 }
 
@@ -596,8 +582,8 @@
 
 bool Resolver::ValidateVariable(const ast::Variable* var) {
   auto* type = variable_to_info_[var]->type;
-  if (auto* r = type->As<sem::ArrayType>()) {
-    if (r->IsRuntimeArray()) {
+  if (auto* r = type->As<sem::Array>()) {
+    if (r->IsRuntimeSized()) {
       diagnostics_.add_error(
           "v-0015",
           "runtime arrays may only appear as the last member of a struct",
@@ -873,8 +859,8 @@
                                     builder_->Symbols().NameFor(func->symbol()),
                                 func->source());
           return false;
-        } else if (auto* arr = member_ty->As<sem::ArrayType>()) {
-          if (arr->IsRuntimeArray()) {
+        } else if (auto* arr = member_ty->As<sem::Array>()) {
+          if (arr->IsRuntimeSized()) {
             diagnostics_.add_error(
                 "entry point IO types cannot contain runtime sized arrays",
                 member->Declaration()->source());
@@ -1276,9 +1262,9 @@
 
   auto* res = TypeOf(expr->array());
   auto* parent_type = res->UnwrapAll();
-  sem::Type* ret = nullptr;
-  if (auto* arr = parent_type->As<sem::ArrayType>()) {
-    ret = arr->type();
+  const sem::Type* ret = nullptr;
+  if (auto* arr = parent_type->As<sem::Array>()) {
+    ret = arr->ElemType();
   } else if (auto* vec = parent_type->As<sem::Vector>()) {
     ret = vec->type();
   } else if (auto* mat = parent_type->As<sem::Matrix>()) {
@@ -1293,8 +1279,8 @@
   // If we're extracting from a pointer, we return a pointer.
   if (auto* ptr = res->As<sem::Pointer>()) {
     ret = builder_->create<sem::Pointer>(ret, ptr->storage_class());
-  } else if (auto* arr = parent_type->As<sem::ArrayType>()) {
-    if (!arr->type()->is_scalar()) {
+  } else if (auto* arr = parent_type->As<sem::Array>()) {
+    if (!arr->ElemType()->is_scalar()) {
       // If we extract a non-scalar from an array then we also get a pointer. We
       // will generate a Function storage class variable to store this into.
       ret = builder_->create<sem::Pointer>(ret, ast::StorageClass::kFunction);
@@ -1459,7 +1445,7 @@
 
       value_cardinality_sum++;
     } else if (auto* value_vec = value_type->As<sem::Vector>()) {
-      sem::Type* value_elem_type = value_vec->type()->UnwrapAll();
+      auto* value_elem_type = value_vec->type()->UnwrapAll();
       // A mismatch of vector type parameter T is only an error if multiple
       // arguments are present. A single argument constructor constitutes a
       // type conversion expression.
@@ -1754,8 +1740,8 @@
   auto* lhs_declared_type = TypeOf(expr->lhs())->UnwrapAll();
   auto* rhs_declared_type = TypeOf(expr->rhs())->UnwrapAll();
 
-  auto* lhs_type = Canonical(lhs_declared_type);
-  auto* rhs_type = Canonical(rhs_declared_type);
+  auto* lhs_type = Canonical(const_cast<sem::Type*>(lhs_declared_type));
+  auto* rhs_type = Canonical(const_cast<sem::Type*>(rhs_declared_type));
 
   auto* lhs_vec = lhs_type->As<Vector>();
   auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
@@ -2006,7 +1992,7 @@
 
   // If the variable has a declared type, resolve it.
   std::string type_name;
-  sem::Type* type = nullptr;
+  const sem::Type* type = nullptr;
   if (auto* ast_ty = var->type()) {
     type_name = ast_ty->FriendlyName(builder_->Symbols());
     type = Type(ast_ty);
@@ -2065,7 +2051,7 @@
   }
   // TODO(bclayton): Remove this and fix tests. We're overriding the semantic
   // type stored in info->type here with a possibly non-canonicalized type.
-  info->type = type;
+  info->type = const_cast<sem::Type*>(type);
   variable_stack_.set(var->symbol(), info);
   current_block_->decls.push_back(var);
 
@@ -2251,8 +2237,7 @@
 
 bool Resolver::DefaultAlignAndSize(const sem::Type* ty,
                                    uint32_t& align,
-                                   uint32_t& size,
-                                   const Source& source) {
+                                   uint32_t& size) {
   static constexpr uint32_t vector_size[] = {
       /* padding */ 0,
       /* padding */ 0,
@@ -2297,76 +2282,71 @@
     align = s->Align();
     size = s->Size();
     return true;
-  } else if (cty->Is<sem::ArrayType>()) {
-    if (auto* sem =
-            Array(ty->UnwrapAliasIfNeeded()->As<sem::ArrayType>(), source)) {
-      align = sem->Align();
-      size = sem->Size();
-      return true;
-    }
-    return false;
+  } else if (auto* a = cty->As<sem::Array>()) {
+    align = a->Align();
+    size = a->SizeInBytes();
+    return true;
   }
   TINT_UNREACHABLE(diagnostics_) << "Invalid type " << ty->TypeInfo().name;
   return false;
 }
 
-const sem::Array* Resolver::Array(const sem::ArrayType* arr,
-                                  const Source& source) {
-  if (auto* sem = builder_->Sem().Get(arr)) {
-    // Semantic info already constructed for this array type
-    return sem;
-  }
+sem::Array* Resolver::Array(const ast::Array* arr) {
+  auto source = arr->source();
 
-  if (!ValidateArray(arr, source)) {
+  auto* el_ty = Type(arr->type());
+  if (!el_ty) {
     return nullptr;
   }
 
-  auto* el_ty = arr->type();
-
   uint32_t el_align = 0;
   uint32_t el_size = 0;
-  if (!DefaultAlignAndSize(el_ty, el_align, el_size, source)) {
+  if (!DefaultAlignAndSize(el_ty, el_align, el_size)) {
     return nullptr;
   }
 
-  auto create_semantic = [&](uint32_t stride) -> sem::Array* {
-    auto align = el_align;
-    // WebGPU requires runtime arrays have at least one element, but the AST
-    // records an element count of 0 for it.
-    auto size = std::max<uint32_t>(arr->size(), 1) * stride;
-    auto* sem = builder_->create<sem::Array>(const_cast<sem::ArrayType*>(arr),
-                                             align, size, stride);
-    builder_->Sem().Add(arr, sem);
-    return sem;
-  };
-
   // Look for explicit stride via [[stride(n)]] decoration
   uint32_t explicit_stride = 0;
   for (auto* deco : arr->decorations()) {
     Mark(deco);
-    if (auto* stride = deco->As<ast::StrideDecoration>()) {
+    if (auto* sd = deco->As<ast::StrideDecoration>()) {
       if (explicit_stride) {
         diagnostics_.add_error(
             "array must have at most one [[stride]] decoration", source);
         return nullptr;
       }
-      explicit_stride = stride->stride();
-      if (!ValidateArrayStrideDecoration(stride, el_size, el_align, source)) {
+      explicit_stride = sd->stride();
+      if (!ValidateArrayStrideDecoration(sd, el_size, el_align, source)) {
         return nullptr;
       }
+      continue;
     }
-  }
-  if (explicit_stride) {
-    return create_semantic(explicit_stride);
+
+    diagnostics_.add_error("decoration is not valid for array types",
+                           deco->source());
+    return nullptr;
   }
 
   // Calculate implicit stride
   auto implicit_stride = utils::RoundUp(el_align, el_size);
-  return create_semantic(implicit_stride);
+
+  auto stride = explicit_stride ? explicit_stride : implicit_stride;
+
+  // WebGPU requires runtime arrays have at least one element, but the AST
+  // records an element count of 0 for it.
+  auto size = std::max<uint32_t>(arr->size(), 1) * stride;
+  auto* sem = builder_->create<sem::Array>(el_ty, arr->size(), el_align, size,
+                                           stride, stride == implicit_stride);
+
+  if (!ValidateArray(sem, source)) {
+    return nullptr;
+  }
+
+  return sem;
 }
 
-bool Resolver::ValidateArray(const sem::ArrayType* arr, const Source& source) {
-  auto* el_ty = arr->type();
+bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) {
+  auto* el_ty = arr->ElemType();
 
   if (!IsStorable(el_ty)) {
     builder_->Diagnostics().add_error(
@@ -2416,8 +2396,8 @@
 
 bool Resolver::ValidateStructure(const sem::Struct* str) {
   for (auto* member : str->Members()) {
-    if (auto* r = member->Type()->UnwrapAll()->As<sem::ArrayType>()) {
-      if (r->IsRuntimeArray()) {
+    if (auto* r = member->Type()->UnwrapAll()->As<sem::Array>()) {
+      if (r->IsRuntimeSized()) {
         if (member != str->Members().back()) {
           diagnostics_.add_error(
               "v-0015",
@@ -2434,14 +2414,6 @@
               member->Declaration()->source());
           return false;
         }
-
-        for (auto* deco : r->decorations()) {
-          if (!deco->Is<ast::StrideDecoration>()) {
-            diagnostics_.add_error("decoration is not valid for array types",
-                                   deco->source());
-            return false;
-          }
-        }
       }
     }
 
@@ -2511,7 +2483,7 @@
     uint32_t offset = struct_size;
     uint32_t align = 0;
     uint32_t size = 0;
-    if (!DefaultAlignAndSize(type, align, size, member->source())) {
+    if (!DefaultAlignAndSize(type, align, size)) {
       return nullptr;
     }
 
@@ -2779,7 +2751,7 @@
 bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
                                             sem::Type* ty,
                                             const Source& usage) {
-  ty = ty->UnwrapIfNeeded();
+  ty = const_cast<sem::Type*>(ty->UnwrapIfNeeded());
 
   if (auto* str = ty->As<sem::Struct>()) {
     if (str->StorageClassUsage().count(sc)) {
@@ -2801,8 +2773,9 @@
     return true;
   }
 
-  if (auto* arr = ty->As<sem::ArrayType>()) {
-    return ApplyStorageClassUsageToType(sc, arr->type(), usage);
+  if (auto* arr = ty->As<sem::Array>()) {
+    return ApplyStorageClassUsageToType(
+        sc, const_cast<sem::Type*>(arr->ElemType()), usage);
   }
 
   if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) {
@@ -2829,7 +2802,8 @@
   return result;
 }
 
-std::string Resolver::VectorPretty(uint32_t size, sem::Type* element_type) {
+std::string Resolver::VectorPretty(uint32_t size,
+                                   const sem::Type* element_type) {
   sem::Vector vec_type(element_type, size);
   return vec_type.FriendlyName(builder_->Symbols());
 }