tint/resolver: Split up the Array resolving logic
In preparation for array constructors that can infer type and count
based on constructor arguments.
Bug: tint:1628
Change-Id: I9f12d7a30de232cf0d34ed7e1a356dd5b92d26d7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97587
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index b1961c3..7eb4d28 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -2259,102 +2259,143 @@
}
sem::Array* Resolver::Array(const ast::Array* arr) {
- auto source = arr->source;
-
- auto* elem_type = Type(arr->type);
- if (!elem_type) {
+ if (!arr->type) {
+ AddError("missing array element type", arr->source.End());
return nullptr;
}
- if (!validator_.IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize()
- AddError(sem_.TypeNameOf(elem_type) + " cannot be used as an element type of an array",
- source);
- return nullptr;
- }
-
- uint32_t el_align = elem_type->Align();
- uint32_t el_size = elem_type->Size();
-
- if (!validator_.NoDuplicateAttributes(arr->attributes)) {
+ auto* el_ty = Type(arr->type);
+ if (!el_ty) {
return nullptr;
}
+
+ // Must run before ArrayAttributes(), which hands the element type's Size() and
+ // Align() to the @stride validator.
+ if (!validator_.IsPlain(el_ty)) {
+ AddError(sem_.TypeNameOf(el_ty) + " cannot be used as an element type of an array",
+ arr->source);
+ return nullptr;
+ }
// Look for explicit stride via @stride(n) attribute
uint32_t explicit_stride = 0;
- for (auto* attr : arr->attributes) {
+ if (!ArrayAttributes(arr->attributes, el_ty, explicit_stride)) {
+ return nullptr;
+ }
+
+ uint32_t el_count = 0; // sem::Array uses a size of 0 for a runtime-sized array.
+
+ // Evaluate the constant array size expression.
+ if (auto* count_expr = arr->count) {
+ if (auto count = ArrayCount(count_expr)) {
+ el_count = count.Get();
+ } else {
+ return nullptr;
+ }
+ }
+
+ auto* out = Array(arr->source, el_ty, el_count, explicit_stride);
+ if (out == nullptr) {
+ return nullptr;
+ }
+
+ if (el_ty->Is<sem::Atomic>()) {
+ atomic_composite_info_.emplace(out, arr->type->source);
+ } else {
+ auto found = atomic_composite_info_.find(el_ty);
+ if (found != atomic_composite_info_.end()) {
+ atomic_composite_info_.emplace(out, found->second);
+ }
+ }
+
+ return out;
+}
+
+// Resolves the array count expression `count_expr`, returning the element count, or
+// utils::Failure if it cannot be resolved or is not a constant integer greater than 0.
+utils::Result<uint32_t> Resolver::ArrayCount(const ast::Expression* count_expr) {
+ // Evaluate the constant array size expression.
+ const auto* count_sem = Materialize(Expression(count_expr));
+ if (!count_sem) {
+ return utils::Failure;
+ }
+
+ auto* count_val = count_sem->ConstantValue();
+ if (!count_val) {
+ AddError("array size must evaluate to a constant integer expression", count_expr->source);
+ return utils::Failure;
+ }
+
+ if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) {
+ AddError("array size must evaluate to a constant integer expression, but is type '" +
+ builder_->FriendlyName(ty) + "'",
+ count_expr->source);
+ return utils::Failure;
+ }
+
+ int64_t count = count_val->As<AInt>();
+ if (count < 1) {
+ AddError("array size (" + std::to_string(count) + ") must be greater than 0",
+ count_expr->source);
+ return utils::Failure;
+ }
+
+ return static_cast<uint32_t>(count);
+}
+
+// Validates the attributes of an array with element type `el_ty`. If a @stride attribute is
+// present, its value is written to `explicit_stride`. Returns false on error.
+bool Resolver::ArrayAttributes(const ast::AttributeList& attributes,
+ const sem::Type* el_ty,
+ uint32_t& explicit_stride) {
+ if (!validator_.NoDuplicateAttributes(attributes)) {
+ return false;
+ }
+
+ for (auto* attr : attributes) {
Mark(attr);
if (auto* sd = attr->As<ast::StrideAttribute>()) {
explicit_stride = sd->stride;
- if (!validator_.ArrayStrideAttribute(sd, el_size, el_align, source)) {
- return nullptr;
+ if (!validator_.ArrayStrideAttribute(sd, el_ty->Size(), el_ty->Align())) {
+ return false;
}
continue;
}
AddError("attribute is not valid for array types", attr->source);
- return nullptr;
+ return false;
}
- // Calculate implicit stride
- uint64_t implicit_stride = utils::RoundUp<uint64_t>(el_align, el_size);
+ return true;
+}
+
+// Creates and validates a sem::Array of `el_count` elements of type `el_ty`. An `el_count` of 0
+// denotes a runtime-sized array, and an `explicit_stride` of 0 selects the implicit stride.
+sem::Array* Resolver::Array(const Source& source,
+ const sem::Type* el_ty,
+ uint32_t el_count,
+ uint32_t explicit_stride) {
+ uint32_t el_align = el_ty->Align();
+ uint32_t el_size = el_ty->Size();
+ uint64_t implicit_stride = el_size ? utils::RoundUp<uint64_t>(el_align, el_size) : 0;
uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
- int64_t count = 0; // sem::Array uses a size of 0 for a runtime-sized array.
-
- // Evaluate the constant array size expression.
- if (auto* count_expr = arr->count) {
- const auto* count_sem = Materialize(Expression(count_expr));
- if (!count_sem) {
- return nullptr;
- }
-
- auto* count_val = count_sem->ConstantValue();
- if (!count_val) {
- AddError("array size must evaluate to a constant integer expression",
- count_expr->source);
- return nullptr;
- }
-
- if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) {
- AddError("array size must evaluate to a constant integer expression, but is type '" +
- builder_->FriendlyName(ty) + "'",
- count_expr->source);
- return nullptr;
- }
-
- count = count_val->As<AInt>();
- if (count < 1) {
- AddError("array size (" + std::to_string(count) + ") must be greater than 0",
- count_expr->source);
- return nullptr;
- }
- }
-
- auto size = std::max<uint64_t>(static_cast<uint32_t>(count), 1u) * stride;
+ auto size = std::max<uint64_t>(el_count, 1u) * stride;
if (size > std::numeric_limits<uint32_t>::max()) {
std::stringstream msg;
msg << "array size (0x" << std::hex << size << ") must not exceed 0xffffffff bytes";
- AddError(msg.str(), arr->source);
+ AddError(msg.str(), source);
return nullptr;
}
- auto* out = builder_->create<sem::Array>(
- elem_type, static_cast<uint32_t>(count), el_align, static_cast<uint32_t>(size),
- static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
+ auto* out = builder_->create<sem::Array>(el_ty, el_count, el_align, static_cast<uint32_t>(size),
+ static_cast<uint32_t>(stride),
+ static_cast<uint32_t>(implicit_stride));
if (!validator_.Array(out, source)) {
return nullptr;
}
- if (elem_type->Is<sem::Atomic>()) {
- atomic_composite_info_.emplace(out, arr->type->source);
- } else {
- auto found = atomic_composite_info_.find(elem_type);
- if (found != atomic_composite_info_.end()) {
- atomic_composite_info_.emplace(out, found->second);
- }
- }
-
return out;
}