tint/resolver: Split up the Array resolving logic

In preparation for array constructors that can infer type and count
based on constructor arguments.

Bug: tint:1628
Change-Id: I9f12d7a30de232cf0d34ed7e1a356dd5b92d26d7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97587
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index 2cb8905..36221e2 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -895,7 +895,8 @@
        << ", should_pass: " << params.should_pass;
     SCOPED_TRACE(ss.str());
 
-    auto* arr = ty.array(Source{{12, 34}}, el_ty, 4_u, params.stride);
+    auto* arr =
+        ty.array(el_ty, 4_u, {create<ast::StrideAttribute>(Source{{12, 34}}, params.stride)});
 
     GlobalVar("myarray", arr, ast::StorageClass::kPrivate);
 
@@ -906,7 +907,7 @@
         EXPECT_EQ(r()->error(),
                   "12:34 error: arrays decorated with the stride attribute must "
                   "have a stride that is at least the size of the element type, "
-                  "and be a multiple of the element type's alignment value.");
+                  "and be a multiple of the element type's alignment value");
     }
 }
 
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index b1961c3..7eb4d28 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -2259,102 +2259,128 @@
 }
 
 sem::Array* Resolver::Array(const ast::Array* arr) {
-    auto source = arr->source;
-
-    auto* elem_type = Type(arr->type);
-    if (!elem_type) {
+    if (!arr->type) {
+        AddError("missing array element type", arr->source.End());
         return nullptr;
     }
 
-    if (!validator_.IsPlain(elem_type)) {  // Check must come before GetDefaultAlignAndSize()
-        AddError(sem_.TypeNameOf(elem_type) + " cannot be used as an element type of an array",
-                 source);
-        return nullptr;
-    }
-
-    uint32_t el_align = elem_type->Align();
-    uint32_t el_size = elem_type->Size();
-
-    if (!validator_.NoDuplicateAttributes(arr->attributes)) {
+    auto* el_ty = Type(arr->type);
+    if (!el_ty) {
         return nullptr;
     }
 
     // Look for explicit stride via @stride(n) attribute
     uint32_t explicit_stride = 0;
-    for (auto* attr : arr->attributes) {
+    if (!ArrayAttributes(arr->attributes, el_ty, explicit_stride)) {
+        return nullptr;
+    }
+
+    uint32_t el_count = 0;  // sem::Array uses a size of 0 for a runtime-sized array.
+
+    // Evaluate the constant array size expression.
+    if (auto* count_expr = arr->count) {
+        if (auto count = ArrayCount(count_expr)) {
+            el_count = count.Get();
+        } else {
+            return nullptr;
+        }
+    }
+
+    auto* out = Array(arr->source, el_ty, el_count, explicit_stride);
+    if (out == nullptr) {
+        return nullptr;
+    }
+
+    if (el_ty->Is<sem::Atomic>()) {
+        atomic_composite_info_.emplace(out, arr->type->source);
+    } else {
+        auto found = atomic_composite_info_.find(el_ty);
+        if (found != atomic_composite_info_.end()) {
+            atomic_composite_info_.emplace(out, found->second);
+        }
+    }
+
+    return out;
+}
+
+utils::Result<uint32_t> Resolver::ArrayCount(const ast::Expression* count_expr) {
+    // Evaluate the constant array size expression.
+    const auto* count_sem = Materialize(Expression(count_expr));
+    if (!count_sem) {
+        return utils::Failure;
+    }
+
+    auto* count_val = count_sem->ConstantValue();
+    if (!count_val) {
+        AddError("array size must evaluate to a constant integer expression", count_expr->source);
+        return utils::Failure;
+    }
+
+    if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) {
+        AddError("array size must evaluate to a constant integer expression, but is type '" +
+                     builder_->FriendlyName(ty) + "'",
+                 count_expr->source);
+        return utils::Failure;
+    }
+
+    int64_t count = count_val->As<AInt>();
+    if (count < 1) {
+        AddError("array size (" + std::to_string(count) + ") must be greater than 0",
+                 count_expr->source);
+        return utils::Failure;
+    }
+
+    return static_cast<uint32_t>(count);
+}
+
+bool Resolver::ArrayAttributes(const ast::AttributeList& attributes,
+                               const sem::Type* el_ty,
+                               uint32_t& explicit_stride) {
+    if (!validator_.NoDuplicateAttributes(attributes)) {
+        return false;
+    }
+
+    for (auto* attr : attributes) {
         Mark(attr);
         if (auto* sd = attr->As<ast::StrideAttribute>()) {
             explicit_stride = sd->stride;
-            if (!validator_.ArrayStrideAttribute(sd, el_size, el_align, source)) {
-                return nullptr;
+            if (!validator_.ArrayStrideAttribute(sd, el_ty->Size(), el_ty->Align())) {
+                return false;
             }
             continue;
         }
 
         AddError("attribute is not valid for array types", attr->source);
-        return nullptr;
+        return false;
     }
 
-    // Calculate implicit stride
-    uint64_t implicit_stride = utils::RoundUp<uint64_t>(el_align, el_size);
+    return true;
+}
 
+sem::Array* Resolver::Array(const Source& source,
+                            const sem::Type* el_ty,
+                            uint32_t el_count,
+                            uint32_t explicit_stride) {
+    uint32_t el_align = el_ty->Align();
+    uint32_t el_size = el_ty->Size();
+    uint64_t implicit_stride = el_size ? utils::RoundUp<uint64_t>(el_align, el_size) : 0;
     uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
 
-    int64_t count = 0;  // sem::Array uses a size of 0 for a runtime-sized array.
-
-    // Evaluate the constant array size expression.
-    if (auto* count_expr = arr->count) {
-        const auto* count_sem = Materialize(Expression(count_expr));
-        if (!count_sem) {
-            return nullptr;
-        }
-
-        auto* count_val = count_sem->ConstantValue();
-        if (!count_val) {
-            AddError("array size must evaluate to a constant integer expression",
-                     count_expr->source);
-            return nullptr;
-        }
-
-        if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) {
-            AddError("array size must evaluate to a constant integer expression, but is type '" +
-                         builder_->FriendlyName(ty) + "'",
-                     count_expr->source);
-            return nullptr;
-        }
-
-        count = count_val->As<AInt>();
-        if (count < 1) {
-            AddError("array size (" + std::to_string(count) + ") must be greater than 0",
-                     count_expr->source);
-            return nullptr;
-        }
-    }
-
-    auto size = std::max<uint64_t>(static_cast<uint32_t>(count), 1u) * stride;
+    auto size = std::max<uint64_t>(el_count, 1u) * stride;
     if (size > std::numeric_limits<uint32_t>::max()) {
         std::stringstream msg;
         msg << "array size (0x" << std::hex << size << ") must not exceed 0xffffffff bytes";
-        AddError(msg.str(), arr->source);
+        AddError(msg.str(), source);
         return nullptr;
     }
-    auto* out = builder_->create<sem::Array>(
-        elem_type, static_cast<uint32_t>(count), el_align, static_cast<uint32_t>(size),
-        static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
+    auto* out = builder_->create<sem::Array>(el_ty, el_count, el_align, static_cast<uint32_t>(size),
+                                             static_cast<uint32_t>(stride),
+                                             static_cast<uint32_t>(implicit_stride));
 
     if (!validator_.Array(out, source)) {
         return nullptr;
     }
 
-    if (elem_type->Is<sem::Atomic>()) {
-        atomic_composite_info_.emplace(out, arr->type->source);
-    } else {
-        auto found = atomic_composite_info_.find(elem_type);
-        if (found != atomic_composite_info_.end()) {
-            atomic_composite_info_.emplace(out, found->second);
-        }
-    }
-
     return out;
 }
 
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index ccd3895..5d38fef 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -286,14 +286,38 @@
     /// @returns the resolved semantic type
     sem::Type* TypeDecl(const ast::TypeDecl* named_type);
 
-    /// Builds and returns the semantic information for the array `arr`.
-    /// This method does not mark the ast::Array node, nor attach the generated
-    /// semantic information to the AST node.
-    /// @returns the semantic Array information, or nullptr if an error is
-    /// raised.
+    /// Builds and returns the semantic information for the AST array `arr`.
+    /// This method does not mark the ast::Array node, nor attach the generated semantic information
+    /// to the AST node.
+    /// @returns the semantic Array information, or nullptr if an error is raised.
     /// @param arr the Array to get semantic information for
     sem::Array* Array(const ast::Array* arr);
 
+    /// Resolves and validates the expression used as the count parameter of an array.
+    /// @param count_expr the expression used as the second template parameter to an array<>.
+    /// @returns the number of elements in the array.
+    utils::Result<uint32_t> ArrayCount(const ast::Expression* count_expr);
+
+    /// Resolves and validates the attributes on an array.
+    /// @param attributes the attributes on the array type.
+    /// @param el_ty the element type of the array.
+    /// @param explicit_stride assigned the specified stride of the array in bytes.
+    /// @returns true on success, false on failure
+    bool ArrayAttributes(const ast::AttributeList& attributes,
+                         const sem::Type* el_ty,
+                         uint32_t& explicit_stride);
+
+    /// Builds and returns the semantic information for an array.
+    /// @returns the semantic Array information, or nullptr if an error is raised.
+    /// @param source the source of the array declaration
+    /// @param el_ty the Array element type
+    /// @param el_count the number of elements in the array. Zero means runtime-sized.
+    /// @param explicit_stride the explicit byte stride of the array. Zero means implicit stride.
+    sem::Array* Array(const Source& source,
+                      const sem::Type* el_ty,
+                      uint32_t el_count,
+                      uint32_t explicit_stride);
+
     /// Builds and returns the semantic information for the alias `alias`.
     /// This method does not mark the ast::Alias node, nor attach the generated
     /// semantic information to the AST node.
diff --git a/src/tint/resolver/type_constructor_validation_test.cc b/src/tint/resolver/type_constructor_validation_test.cc
index 0fb01bc..aa0d65a 100644
--- a/src/tint/resolver/type_constructor_validation_test.cc
+++ b/src/tint/resolver/type_constructor_validation_test.cc
@@ -527,9 +527,7 @@
     WrapInFunction(tc);
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(),
-              "12:34 error: type in array constructor does not match array type: "
-              "expected 'u32', found 'f32'");
+    EXPECT_EQ(r()->error(), R"(12:34 error: 'f32' cannot be used to construct an array of 'u32')");
 }
 
 TEST_F(ResolverTypeConstructorValidationTest,
@@ -539,9 +537,7 @@
     WrapInFunction(tc);
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(),
-              "12:34 error: type in array constructor does not match array type: "
-              "expected 'f32', found 'i32'");
+    EXPECT_EQ(r()->error(), R"(12:34 error: 'i32' cannot be used to construct an array of 'f32')");
 }
 
 TEST_F(ResolverTypeConstructorValidationTest,
@@ -552,9 +548,7 @@
     WrapInFunction(tc);
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(),
-              "12:34 error: type in array constructor does not match array type: "
-              "expected 'u32', found 'i32'");
+    EXPECT_EQ(r()->error(), R"(12:34 error: 'i32' cannot be used to construct an array of 'u32')");
 }
 
 TEST_F(ResolverTypeConstructorValidationTest,
@@ -564,8 +558,7 @@
     WrapInFunction(tc);
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: type in array constructor does not match array type: "
-              "expected 'i32', found 'vec2<i32>'");
+              R"(12:34 error: 'vec2<i32>' cannot be used to construct an array of 'i32')");
 }
 
 TEST_F(ResolverTypeConstructorValidationTest,
@@ -579,8 +572,7 @@
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: type in array constructor does not match array type: "
-              "expected 'vec3<i32>', found 'vec3<u32>'");
+              R"(12:34 error: 'vec3<u32>' cannot be used to construct an array of 'vec3<i32>')");
 }
 
 TEST_F(ResolverTypeConstructorValidationTest,
@@ -594,8 +586,7 @@
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: type in array constructor does not match array type: "
-              "expected 'vec3<i32>', found 'vec3<bool>'");
+              R"(12:34 error: 'vec3<bool>' cannot be used to construct an array of 'vec3<i32>')");
 }
 
 TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_ArrayOfArray_SubElemSizeMismatch) {
@@ -607,9 +598,9 @@
     WrapInFunction(t);
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(),
-              "12:34 error: type in array constructor does not match array type: "
-              "expected 'array<i32, 2>', found 'array<i32, 3>'");
+    EXPECT_EQ(
+        r()->error(),
+        R"(12:34 error: 'array<i32, 3>' cannot be used to construct an array of 'array<i32, 2>')");
 }
 
 TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_ArrayOfArray_SubElemTypeMismatch) {
@@ -621,9 +612,9 @@
     WrapInFunction(t);
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(),
-              "12:34 error: type in array constructor does not match array type: "
-              "expected 'array<i32, 2>', found 'array<u32, 2>'");
+    EXPECT_EQ(
+        r()->error(),
+        R"(12:34 error: 'array<u32, 2>' cannot be used to construct an array of 'array<i32, 2>')");
 }
 
 TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Array_TooFewElements) {
@@ -656,7 +647,7 @@
     WrapInFunction(tc);
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(), "error: cannot init a runtime-sized array");
+    EXPECT_EQ(r()->error(), "error: cannot construct a runtime-sized array");
 }
 
 TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Array_RuntimeZeroValue) {
@@ -665,7 +656,7 @@
     WrapInFunction(tc);
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(), "error: cannot init a runtime-sized array");
+    EXPECT_EQ(r()->error(), "error: cannot construct a runtime-sized array");
 }
 
 }  // namespace ArrayConstructor
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 940e154..d854447 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1875,17 +1875,16 @@
     for (auto* value : values) {
         auto* value_ty = sem_.TypeOf(value)->UnwrapRef();
         if (value_ty != elem_ty) {
-            AddError(
-                "type in array constructor does not match array type: "
-                "expected '" +
-                    sem_.TypeNameOf(elem_ty) + "', found '" + sem_.TypeNameOf(value_ty) + "'",
-                value->source);
+            AddError("'" + sem_.TypeNameOf(value_ty) +
+                         "' cannot be used to construct an array of '" + sem_.TypeNameOf(elem_ty) +
+                         "'",
+                     value->source);
             return false;
         }
     }
 
     if (array_type->IsRuntimeSized()) {
-        AddError("cannot init a runtime-sized array", ctor->source);
+        AddError("cannot construct a runtime-sized array", ctor->source);
         return false;
     } else if (!elem_ty->IsConstructible()) {
         AddError("array constructor has non-constructible element type", ctor->source);
@@ -2011,6 +2010,11 @@
 bool Validator::Array(const sem::Array* arr, const Source& source) const {
     auto* el_ty = arr->ElemType();
 
+    if (!IsPlain(el_ty)) {
+        AddError(sem_.TypeNameOf(el_ty) + " cannot be used as an element type of an array", source);
+        return false;
+    }
+
     if (!IsFixedFootprint(el_ty)) {
         AddError("an array element type cannot contain a runtime-sized array", source);
         return false;
@@ -2020,8 +2024,7 @@
 
 bool Validator::ArrayStrideAttribute(const ast::StrideAttribute* attr,
                                      uint32_t el_size,
-                                     uint32_t el_align,
-                                     const Source& source) const {
+                                     uint32_t el_align) const {
     auto stride = attr->stride;
     bool is_valid_stride = (stride >= el_size) && (stride >= el_align) && (stride % el_align == 0);
     if (!is_valid_stride) {
@@ -2032,8 +2035,8 @@
         AddError(
             "arrays decorated with the stride attribute must have a stride "
             "that is at least the size of the element type, and be a multiple "
-            "of the element type's alignment value.",
-            source);
+            "of the element type's alignment value",
+            attr->source);
         return false;
     }
     return true;
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h
index e9f9cea..385e020 100644
--- a/src/tint/resolver/validator.h
+++ b/src/tint/resolver/validator.h
@@ -131,12 +131,10 @@
     /// @param attr the stride attribute to validate
     /// @param el_size the element size
     /// @param el_align the element alignment
-    /// @param source the source of the attribute
     /// @returns true on success, false otherwise
     bool ArrayStrideAttribute(const ast::StrideAttribute* attr,
                               uint32_t el_size,
-                              uint32_t el_align,
-                              const Source& source) const;
+                              uint32_t el_align) const;
 
     /// Validates an atomic
     /// @param a the atomic ast node to validate