sem: Fold together sem::Array and sem::ArrayType
There's now no need to have both.
Removes a whole bunch of Sem().Get() smell, and simplifies the resolver.
Also fixes a long-standing issue where an array with an explicit [[stride]] decoration equal to its implicit stride would result in a different type from an array without the decoration.
Bug: tint:724
Fixed: tint:782
Change-Id: I0202459009cd45be427cdb621993a5a3b07ff51e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50301
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 32273e2..a3162c2 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -178,8 +178,8 @@
if (type->is_scalar() || type->Is<sem::Vector>() || type->Is<sem::Matrix>()) {
return true;
}
- if (auto* arr = type->As<sem::ArrayType>()) {
- return IsStorable(arr->type());
+ if (auto* arr = type->As<sem::Array>()) {
+ return IsStorable(arr->ElemType());
}
if (auto* str = type->As<sem::Struct>()) {
for (const auto* member : str->Members()) {
@@ -204,8 +204,8 @@
if (auto* mat = type->As<sem::Matrix>()) {
return IsHostShareable(mat->type());
}
- if (auto* arr = type->As<sem::ArrayType>()) {
- return IsHostShareable(arr->type());
+ if (auto* arr = type->As<sem::Array>()) {
+ return IsHostShareable(arr->ElemType());
}
if (auto* str = type->As<sem::Struct>()) {
for (auto* member : str->Members()) {
@@ -287,7 +287,7 @@
// TODO(crbug.com/tint/724) - Remove once tint:724 is complete.
// ast::AccessDecorations are generated by the WGSL parser, used to
// build sem::AccessControls and then leaked.
- // ast::StrideDecoration are used to build a sem::ArrayTypes, but
+ // ast::StrideDecoration are used to build a sem::Arrays, but
// multiple arrays of the same stride, size and element type are
// currently de-duplicated by the type manager, and we leak these
// decorations.
@@ -350,14 +350,7 @@
return nullptr;
}
if (auto* t = ty->As<ast::Array>()) {
- if (auto* el = Type(t->type())) {
- auto* sem = builder_->create<sem::ArrayType>(
- const_cast<sem::Type*>(el), t->size(), t->decorations());
- if (Array(sem, ty->source())) {
- return sem;
- }
- }
- return nullptr;
+ return Array(t);
}
if (auto* t = ty->As<ast::Pointer>()) {
if (auto* el = Type(t->type())) {
@@ -420,9 +413,10 @@
return s;
}
-Resolver::VariableInfo* Resolver::Variable(ast::Variable* var,
- sem::Type* type, /* = nullptr */
- std::string type_name /* = "" */) {
+Resolver::VariableInfo* Resolver::Variable(
+ ast::Variable* var,
+ const sem::Type* type, /* = nullptr */
+ std::string type_name /* = "" */) {
auto it = variable_to_info_.find(var);
if (it != variable_to_info_.end()) {
return it->second;
@@ -436,18 +430,10 @@
return nullptr;
}
- auto* ctype = Canonical(type);
+ auto* ctype = Canonical(const_cast<sem::Type*>(type));
auto* info = variable_infos_.Create(var, ctype, type_name);
variable_to_info_.emplace(var, info);
- // TODO(bclayton): Why is this here? Needed?
- // Resolve variable's type
- if (auto* arr = info->type->As<sem::ArrayType>()) {
- if (!Array(arr, var->source())) {
- return nullptr;
- }
- }
-
return info;
}
@@ -596,8 +582,8 @@
bool Resolver::ValidateVariable(const ast::Variable* var) {
auto* type = variable_to_info_[var]->type;
- if (auto* r = type->As<sem::ArrayType>()) {
- if (r->IsRuntimeArray()) {
+ if (auto* r = type->As<sem::Array>()) {
+ if (r->IsRuntimeSized()) {
diagnostics_.add_error(
"v-0015",
"runtime arrays may only appear as the last member of a struct",
@@ -873,8 +859,8 @@
builder_->Symbols().NameFor(func->symbol()),
func->source());
return false;
- } else if (auto* arr = member_ty->As<sem::ArrayType>()) {
- if (arr->IsRuntimeArray()) {
+ } else if (auto* arr = member_ty->As<sem::Array>()) {
+ if (arr->IsRuntimeSized()) {
diagnostics_.add_error(
"entry point IO types cannot contain runtime sized arrays",
member->Declaration()->source());
@@ -1276,9 +1262,9 @@
auto* res = TypeOf(expr->array());
auto* parent_type = res->UnwrapAll();
- sem::Type* ret = nullptr;
- if (auto* arr = parent_type->As<sem::ArrayType>()) {
- ret = arr->type();
+ const sem::Type* ret = nullptr;
+ if (auto* arr = parent_type->As<sem::Array>()) {
+ ret = arr->ElemType();
} else if (auto* vec = parent_type->As<sem::Vector>()) {
ret = vec->type();
} else if (auto* mat = parent_type->As<sem::Matrix>()) {
@@ -1293,8 +1279,8 @@
// If we're extracting from a pointer, we return a pointer.
if (auto* ptr = res->As<sem::Pointer>()) {
ret = builder_->create<sem::Pointer>(ret, ptr->storage_class());
- } else if (auto* arr = parent_type->As<sem::ArrayType>()) {
- if (!arr->type()->is_scalar()) {
+ } else if (auto* arr = parent_type->As<sem::Array>()) {
+ if (!arr->ElemType()->is_scalar()) {
// If we extract a non-scalar from an array then we also get a pointer. We
// will generate a Function storage class variable to store this into.
ret = builder_->create<sem::Pointer>(ret, ast::StorageClass::kFunction);
@@ -1459,7 +1445,7 @@
value_cardinality_sum++;
} else if (auto* value_vec = value_type->As<sem::Vector>()) {
- sem::Type* value_elem_type = value_vec->type()->UnwrapAll();
+ auto* value_elem_type = value_vec->type()->UnwrapAll();
// A mismatch of vector type parameter T is only an error if multiple
// arguments are present. A single argument constructor constitutes a
// type conversion expression.
@@ -1754,8 +1740,8 @@
auto* lhs_declared_type = TypeOf(expr->lhs())->UnwrapAll();
auto* rhs_declared_type = TypeOf(expr->rhs())->UnwrapAll();
- auto* lhs_type = Canonical(lhs_declared_type);
- auto* rhs_type = Canonical(rhs_declared_type);
+ auto* lhs_type = Canonical(const_cast<sem::Type*>(lhs_declared_type));
+ auto* rhs_type = Canonical(const_cast<sem::Type*>(rhs_declared_type));
auto* lhs_vec = lhs_type->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
@@ -2006,7 +1992,7 @@
// If the variable has a declared type, resolve it.
std::string type_name;
- sem::Type* type = nullptr;
+ const sem::Type* type = nullptr;
if (auto* ast_ty = var->type()) {
type_name = ast_ty->FriendlyName(builder_->Symbols());
type = Type(ast_ty);
@@ -2065,7 +2051,7 @@
}
// TODO(bclayton): Remove this and fix tests. We're overriding the semantic
// type stored in info->type here with a possibly non-canonicalized type.
- info->type = type;
+ info->type = const_cast<sem::Type*>(type);
variable_stack_.set(var->symbol(), info);
current_block_->decls.push_back(var);
@@ -2251,8 +2237,7 @@
bool Resolver::DefaultAlignAndSize(const sem::Type* ty,
uint32_t& align,
- uint32_t& size,
- const Source& source) {
+ uint32_t& size) {
static constexpr uint32_t vector_size[] = {
/* padding */ 0,
/* padding */ 0,
@@ -2297,76 +2282,71 @@
align = s->Align();
size = s->Size();
return true;
- } else if (cty->Is<sem::ArrayType>()) {
- if (auto* sem =
- Array(ty->UnwrapAliasIfNeeded()->As<sem::ArrayType>(), source)) {
- align = sem->Align();
- size = sem->Size();
- return true;
- }
- return false;
+ } else if (auto* a = cty->As<sem::Array>()) {
+ align = a->Align();
+ size = a->SizeInBytes();
+ return true;
}
TINT_UNREACHABLE(diagnostics_) << "Invalid type " << ty->TypeInfo().name;
return false;
}
-const sem::Array* Resolver::Array(const sem::ArrayType* arr,
- const Source& source) {
- if (auto* sem = builder_->Sem().Get(arr)) {
- // Semantic info already constructed for this array type
- return sem;
- }
+sem::Array* Resolver::Array(const ast::Array* arr) {
+ auto source = arr->source();
- if (!ValidateArray(arr, source)) {
+ auto* el_ty = Type(arr->type());
+ if (!el_ty) {
return nullptr;
}
- auto* el_ty = arr->type();
-
uint32_t el_align = 0;
uint32_t el_size = 0;
- if (!DefaultAlignAndSize(el_ty, el_align, el_size, source)) {
+ if (!DefaultAlignAndSize(el_ty, el_align, el_size)) {
return nullptr;
}
- auto create_semantic = [&](uint32_t stride) -> sem::Array* {
- auto align = el_align;
- // WebGPU requires runtime arrays have at least one element, but the AST
- // records an element count of 0 for it.
- auto size = std::max<uint32_t>(arr->size(), 1) * stride;
- auto* sem = builder_->create<sem::Array>(const_cast<sem::ArrayType*>(arr),
- align, size, stride);
- builder_->Sem().Add(arr, sem);
- return sem;
- };
-
// Look for explicit stride via [[stride(n)]] decoration
uint32_t explicit_stride = 0;
for (auto* deco : arr->decorations()) {
Mark(deco);
- if (auto* stride = deco->As<ast::StrideDecoration>()) {
+ if (auto* sd = deco->As<ast::StrideDecoration>()) {
if (explicit_stride) {
diagnostics_.add_error(
"array must have at most one [[stride]] decoration", source);
return nullptr;
}
- explicit_stride = stride->stride();
- if (!ValidateArrayStrideDecoration(stride, el_size, el_align, source)) {
+ explicit_stride = sd->stride();
+ if (!ValidateArrayStrideDecoration(sd, el_size, el_align, source)) {
return nullptr;
}
+ continue;
}
- }
- if (explicit_stride) {
- return create_semantic(explicit_stride);
+
+ diagnostics_.add_error("decoration is not valid for array types",
+ deco->source());
+ return nullptr;
}
// Calculate implicit stride
auto implicit_stride = utils::RoundUp(el_align, el_size);
- return create_semantic(implicit_stride);
+
+ auto stride = explicit_stride ? explicit_stride : implicit_stride;
+
+ // WebGPU requires runtime arrays have at least one element, but the AST
+ // records an element count of 0 for it.
+ auto size = std::max<uint32_t>(arr->size(), 1) * stride;
+ auto* sem = builder_->create<sem::Array>(el_ty, arr->size(), el_align, size,
+ stride, stride == implicit_stride);
+
+ if (!ValidateArray(sem, source)) {
+ return nullptr;
+ }
+
+ return sem;
}
-bool Resolver::ValidateArray(const sem::ArrayType* arr, const Source& source) {
- auto* el_ty = arr->type();
+bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) {
+ auto* el_ty = arr->ElemType();
if (!IsStorable(el_ty)) {
builder_->Diagnostics().add_error(
@@ -2416,8 +2396,8 @@
bool Resolver::ValidateStructure(const sem::Struct* str) {
for (auto* member : str->Members()) {
- if (auto* r = member->Type()->UnwrapAll()->As<sem::ArrayType>()) {
- if (r->IsRuntimeArray()) {
+ if (auto* r = member->Type()->UnwrapAll()->As<sem::Array>()) {
+ if (r->IsRuntimeSized()) {
if (member != str->Members().back()) {
diagnostics_.add_error(
"v-0015",
@@ -2434,14 +2414,6 @@
member->Declaration()->source());
return false;
}
-
- for (auto* deco : r->decorations()) {
- if (!deco->Is<ast::StrideDecoration>()) {
- diagnostics_.add_error("decoration is not valid for array types",
- deco->source());
- return false;
- }
- }
}
}
@@ -2511,7 +2483,7 @@
uint32_t offset = struct_size;
uint32_t align = 0;
uint32_t size = 0;
- if (!DefaultAlignAndSize(type, align, size, member->source())) {
+ if (!DefaultAlignAndSize(type, align, size)) {
return nullptr;
}
@@ -2779,7 +2751,7 @@
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
sem::Type* ty,
const Source& usage) {
- ty = ty->UnwrapIfNeeded();
+ ty = const_cast<sem::Type*>(ty->UnwrapIfNeeded());
if (auto* str = ty->As<sem::Struct>()) {
if (str->StorageClassUsage().count(sc)) {
@@ -2801,8 +2773,9 @@
return true;
}
- if (auto* arr = ty->As<sem::ArrayType>()) {
- return ApplyStorageClassUsageToType(sc, arr->type(), usage);
+ if (auto* arr = ty->As<sem::Array>()) {
+ return ApplyStorageClassUsageToType(
+ sc, const_cast<sem::Type*>(arr->ElemType()), usage);
}
if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) {
@@ -2829,7 +2802,8 @@
return result;
}
-std::string Resolver::VectorPretty(uint32_t size, sem::Type* element_type) {
+std::string Resolver::VectorPretty(uint32_t size,
+ const sem::Type* element_type) {
sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(builder_->Symbols());
}