Rework Resolver so that we construct semantic types in a single pass.
The semantic nodes cannot be fully immutable, as they contain cyclic
references. Remove Resolver::CreateSemanticNodes(), and instead
construct and mutate the semantic nodes in the single traversal pass.
Give up on trying to maintain the 'authored' type names (aliased names).
These are a nightmare to maintain, and provide limited value.
This significantly simplifies the Resolver, and allows us to generate
more semantic-to-semantic references, reducing sem -> ast -> sem hops.
Note: This change introduces constant value propagation across constant
variables. This is unlocked by the earlier construction of the
sem::Variable.
Change-Id: I592092fdc47fe24d30e512952511c9ab7c16d7a1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68406
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/array_accessor_test.cc b/src/resolver/array_accessor_test.cc
index a565d0c..728c119 100644
--- a/src/resolver/array_accessor_test.cc
+++ b/src/resolver/array_accessor_test.cc
@@ -289,8 +289,9 @@
Func("func", {p}, ty.f32(), {Decl(idx), Decl(x), Return(x)});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(),
- "12:34 error: cannot index type 'ptr<function, vec4<f32>>'");
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: cannot index type 'ptr<function, vec4<f32>, read_write>'");
}
TEST_F(ResolverArrayAccessorTest, Exr_Deref_BadParent) {
diff --git a/src/resolver/assignment_validation_test.cc b/src/resolver/assignment_validation_test.cc
index f8d9390..1981c87 100644
--- a/src/resolver/assignment_validation_test.cc
+++ b/src/resolver/assignment_validation_test.cc
@@ -102,7 +102,7 @@
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: cannot assign 'array<f32, len>' to 'array<f32, 4>'");
+ "12:34 error: cannot assign 'array<f32, 5>' to 'array<f32, 4>'");
}
TEST_F(ResolverAssignmentValidationTest,
@@ -332,7 +332,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
- "12:34 error: cannot assign 'ref<storage, array<i32>, read>' to '_'. "
+ "12:34 error: cannot assign 'array<i32>' to '_'. "
"'_' can only be assigned a constructible, pointer, texture or sampler "
"type");
}
diff --git a/src/resolver/compound_statement_test.cc b/src/resolver/compound_statement_test.cc
index a13cac6..2ace78e 100644
--- a/src/resolver/compound_statement_test.cc
+++ b/src/resolver/compound_statement_test.cc
@@ -43,7 +43,7 @@
ASSERT_TRUE(s->Block()->Is<sem::FunctionBlockStatement>());
EXPECT_EQ(s->Block(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Block(), s->FindFirstParent<sem::FunctionBlockStatement>());
- EXPECT_EQ(s->Block()->As<sem::FunctionBlockStatement>()->Function(), f);
+ EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Block()->Parent(), nullptr);
}
@@ -74,8 +74,7 @@
EXPECT_EQ(s->Block()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
ASSERT_TRUE(s->Block()->Parent()->Is<sem::FunctionBlockStatement>());
- EXPECT_EQ(
- s->Block()->Parent()->As<sem::FunctionBlockStatement>()->Function(), f);
+ EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Block()->Parent()->Parent(), nullptr);
}
}
@@ -118,7 +117,7 @@
EXPECT_TRUE(
Is<sem::FunctionBlockStatement>(s->Parent()->Parent()->Parent()));
- EXPECT_EQ(s->FindFirstParent<sem::FunctionBlockStatement>()->Function(), f);
+ EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(), nullptr);
}
@@ -144,7 +143,7 @@
s->FindFirstParent<sem::FunctionBlockStatement>());
EXPECT_TRUE(Is<sem::FunctionBlockStatement>(
s->Parent()->Parent()->Parent()->Parent()));
- EXPECT_EQ(s->FindFirstParent<sem::FunctionBlockStatement>()->Function(), f);
+ EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent()->Parent(), nullptr);
}
@@ -213,12 +212,7 @@
Is<sem::FunctionBlockStatement>(s->Block()->Parent()->Parent()));
EXPECT_EQ(s->Block()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
- EXPECT_EQ(s->Block()
- ->Parent()
- ->Parent()
- ->As<sem::FunctionBlockStatement>()
- ->Function(),
- f);
+ EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Block()->Parent()->Parent()->Parent(), nullptr);
}
}
diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc
index dc71cae..ec44227 100644
--- a/src/resolver/function_validation_test.cc
+++ b/src/resolver/function_validation_test.cc
@@ -388,7 +388,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: return statement type must match its function return "
- "type, returned 'u32', expected 'myf32'");
+ "type, returned 'u32', expected 'f32'");
}
TEST_F(ResolverFunctionValidationTest, CannotCallEntryPoint) {
diff --git a/src/resolver/ptr_ref_test.cc b/src/resolver/ptr_ref_test.cc
index f842549..4810537 100644
--- a/src/resolver/ptr_ref_test.cc
+++ b/src/resolver/ptr_ref_test.cc
@@ -98,11 +98,16 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
- ASSERT_TRUE(TypeOf(function_ptr)->Is<sem::Pointer>());
- ASSERT_TRUE(TypeOf(private_ptr)->Is<sem::Pointer>());
- ASSERT_TRUE(TypeOf(workgroup_ptr)->Is<sem::Pointer>());
- ASSERT_TRUE(TypeOf(uniform_ptr)->Is<sem::Pointer>());
- ASSERT_TRUE(TypeOf(storage_ptr)->Is<sem::Pointer>());
+ ASSERT_TRUE(TypeOf(function_ptr)->Is<sem::Pointer>())
+ << "function_ptr is " << TypeOf(function_ptr)->TypeInfo().name;
+ ASSERT_TRUE(TypeOf(private_ptr)->Is<sem::Pointer>())
+ << "private_ptr is " << TypeOf(private_ptr)->TypeInfo().name;
+ ASSERT_TRUE(TypeOf(workgroup_ptr)->Is<sem::Pointer>())
+ << "workgroup_ptr is " << TypeOf(workgroup_ptr)->TypeInfo().name;
+ ASSERT_TRUE(TypeOf(uniform_ptr)->Is<sem::Pointer>())
+ << "uniform_ptr is " << TypeOf(uniform_ptr)->TypeInfo().name;
+ ASSERT_TRUE(TypeOf(storage_ptr)->Is<sem::Pointer>())
+ << "storage_ptr is " << TypeOf(storage_ptr)->TypeInfo().name;
EXPECT_EQ(TypeOf(function_ptr)->As<sem::Pointer>()->Access(),
ast::Access::kReadWrite);
diff --git a/src/resolver/ptr_ref_validation_test.cc b/src/resolver/ptr_ref_validation_test.cc
index 06daeb5..87886bb 100644
--- a/src/resolver/ptr_ref_validation_test.cc
+++ b/src/resolver/ptr_ref_validation_test.cc
@@ -167,7 +167,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: cannot initialize let of type "
- "'ptr<storage, i32>' with value of type "
+ "'ptr<storage, i32, read>' with value of type "
"'ptr<storage, i32, read_write>'");
}
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 9664f00..7808b8f 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -161,22 +161,6 @@
Resolver::~Resolver() = default;
-void Resolver::set_referenced_from_function_if_needed(VariableInfo* var,
- bool local) {
- if (current_function_ == nullptr) {
- return;
- }
-
- if (var->kind != VariableKind::kGlobal) {
- return;
- }
-
- current_function_->referenced_module_vars.add(var);
- if (local) {
- current_function_->local_referenced_module_vars.add(var);
- }
-}
-
bool Resolver::Resolve() {
if (builder_->Diagnostics().contains_errors()) {
return false;
@@ -190,23 +174,19 @@
return false;
}
- // Even if resolving failed, create all the semantic nodes for information we
- // did generate.
- CreateSemanticNodes();
-
return result;
}
// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
bool Resolver::IsPlain(const sem::Type* type) const {
- return type->is_scalar() || type->Is<sem::Atomic>() ||
- type->Is<sem::Vector>() || type->Is<sem::Matrix>() ||
- type->Is<sem::Array>() || type->Is<sem::Struct>();
+ return type->is_scalar() ||
+ type->IsAnyOf<sem::Atomic, sem::Vector, sem::Matrix, sem::Array,
+ sem::Struct>();
}
// https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
bool Resolver::IsStorable(const sem::Type* type) const {
- return IsPlain(type) || type->Is<sem::Texture>() || type->Is<sem::Sampler>();
+ return IsPlain(type) || type->IsAnyOf<sem::Texture, sem::Sampler>();
}
// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
@@ -443,53 +423,36 @@
return true;
}
-Resolver::VariableInfo* Resolver::Variable(const ast::Variable* var,
- VariableKind kind,
- uint32_t index /* = 0 */) {
- if (variable_to_info_.count(var)) {
- TINT_ICE(Resolver, diagnostics_)
- << "Variable " << builder_->Symbols().NameFor(var->symbol)
- << " already resolved";
- return nullptr;
- }
-
- std::string type_name;
- const sem::Type* storage_type = nullptr;
+sem::Variable* Resolver::Variable(const ast::Variable* var,
+ VariableKind kind,
+ uint32_t index /* = 0 */) {
+ const sem::Type* storage_ty = nullptr;
// If the variable has a declared type, resolve it.
if (auto* ty = var->type) {
- type_name = ty->FriendlyName(builder_->Symbols());
- storage_type = Type(ty);
- if (!storage_type) {
+ storage_ty = Type(ty);
+ if (!storage_ty) {
return nullptr;
}
}
- std::string rhs_type_name;
- const sem::Type* rhs_type = nullptr;
+ const sem::Expression* rhs = nullptr;
// Does the variable have a constructor?
- if (auto* ctor = var->constructor) {
- if (!Expression(var->constructor)) {
- return nullptr;
- }
-
- // Fetch the constructor's type
- rhs_type_name = TypeNameOf(ctor);
- rhs_type = TypeOf(ctor);
- if (!rhs_type) {
+ if (var->constructor) {
+ rhs = Expression(var->constructor);
+ if (!rhs) {
return nullptr;
}
// If the variable has no declared type, infer it from the RHS
- if (!storage_type) {
+ if (!storage_ty) {
if (!var->is_const && kind == VariableKind::kGlobal) {
AddError("global var declaration must specify a type", var->source);
return nullptr;
}
- type_name = rhs_type_name;
- storage_type = rhs_type->UnwrapRef(); // Implicit load of RHS
+ storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS
}
} else if (var->is_const && kind != VariableKind::kParameter &&
!ast::HasDecoration<ast::OverrideDecoration>(var->decorations)) {
@@ -504,7 +467,7 @@
return nullptr;
}
- if (!storage_type) {
+ if (!storage_ty) {
TINT_ICE(Resolver, diagnostics_)
<< "failed to determine storage type for variable '" +
builder_->Symbols().NameFor(var->symbol) + "'\n"
@@ -517,7 +480,7 @@
// No declared storage class. Infer from usage / type.
if (kind == VariableKind::kLocal) {
storage_class = ast::StorageClass::kFunction;
- } else if (storage_type->UnwrapRef()->is_handle()) {
+ } else if (storage_ty->UnwrapRef()->is_handle()) {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
// If the store type is a texture type or a sampler type, then the
// variable declaration must not have a storage class decoration. The
@@ -526,31 +489,97 @@
}
}
+ if (kind == VariableKind::kLocal && !var->is_const &&
+ storage_class != ast::StorageClass::kFunction &&
+ IsValidationEnabled(var->decorations,
+ ast::DisabledValidation::kIgnoreStorageClass)) {
+ AddError("function variable has a non-function storage class", var->source);
+ return nullptr;
+ }
+
auto access = var->declared_access;
if (access == ast::Access::kUndefined) {
access = DefaultAccessForStorageClass(storage_class);
}
- auto* type = storage_type;
+ auto* var_ty = storage_ty;
if (!var->is_const) {
// Variable declaration. Unlike `let`, `var` has storage.
// Variables are always of a reference type to the declared storage type.
- type =
- builder_->create<sem::Reference>(storage_type, storage_class, access);
+ var_ty =
+ builder_->create<sem::Reference>(storage_ty, storage_class, access);
}
- if (rhs_type &&
- !ValidateVariableConstructor(var, storage_class, storage_type, type_name,
- rhs_type, rhs_type_name)) {
+ if (rhs && !ValidateVariableConstructor(var, storage_class, storage_ty,
+ rhs->Type())) {
return nullptr;
}
- auto* info =
- variable_infos_.Create(var, const_cast<sem::Type*>(type), type_name,
- storage_class, access, kind, index);
- variable_to_info_.emplace(var, info);
+ if (!ApplyStorageClassUsageToType(
+ storage_class, const_cast<sem::Type*>(var_ty), var->source)) {
+ AddNote(
+ std::string("while instantiating ") +
+ ((kind == VariableKind::kParameter) ? "parameter " : "variable ") +
+ builder_->Symbols().NameFor(var->symbol),
+ var->source);
+ return nullptr;
+ }
- return info;
+ if (kind == VariableKind::kParameter) {
+ if (auto* ptr = var_ty->As<sem::Pointer>()) {
+ // For MSL, we push module-scope variables into the entry point as pointer
+ // parameters, so we also need to handle their store type.
+ if (!ApplyStorageClassUsageToType(
+ ptr->StorageClass(), const_cast<sem::Type*>(ptr->StoreType()),
+ var->source)) {
+ AddNote("while instantiating parameter " +
+ builder_->Symbols().NameFor(var->symbol),
+ var->source);
+ return nullptr;
+ }
+ }
+ }
+
+ switch (kind) {
+ case VariableKind::kGlobal: {
+ sem::BindingPoint binding_point;
+ if (auto bp = var->BindingPoint()) {
+ binding_point = {bp.group->value, bp.binding->value};
+ }
+
+ auto* global = builder_->create<sem::GlobalVariable>(
+ var, var_ty, storage_class, access,
+ (rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{},
+ binding_point);
+
+ if (auto* override =
+ ast::GetDecoration<ast::OverrideDecoration>(var->decorations)) {
+ if (override->has_value) {
+ global->SetConstantId(static_cast<uint16_t>(override->value));
+ }
+ }
+
+ builder_->Sem().Add(var, global);
+ return global;
+ }
+ case VariableKind::kLocal: {
+ auto* local = builder_->create<sem::LocalVariable>(
+ var, var_ty, storage_class, access,
+ (rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{});
+ builder_->Sem().Add(var, local);
+ return local;
+ }
+ case VariableKind::kParameter: {
+ auto* param = builder_->create<sem::Parameter>(var, index, var_ty,
+ storage_class, access);
+ builder_->Sem().Add(var, param);
+ return param;
+ }
+ }
+
+ TINT_UNREACHABLE(Resolver, diagnostics_)
+ << "unhandled VariableKind " << static_cast<int>(kind);
+ return nullptr;
}
ast::Access Resolver::DefaultAccessForStorageClass(
@@ -603,23 +632,23 @@
next_constant_id = constant_id + 1;
}
- variable_to_info_[var]->constant_id = constant_id;
+ auto* sem = Sem<sem::GlobalVariable>(var);
+ const_cast<sem::GlobalVariable*>(sem)->SetConstantId(constant_id);
}
}
bool Resolver::ValidateVariableConstructor(const ast::Variable* var,
ast::StorageClass storage_class,
- const sem::Type* storage_type,
- const std::string& type_name,
- const sem::Type* rhs_type,
- const std::string& rhs_type_name) {
- auto* value_type = rhs_type->UnwrapRef(); // Implicit load of RHS
+ const sem::Type* storage_ty,
+ const sem::Type* rhs_ty) {
+ auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
// Value type has to match storage type
- if (storage_type != value_type) {
+ if (storage_ty != value_type) {
std::string decl = var->is_const ? "let" : "var";
- AddError("cannot initialize " + decl + " of type '" + type_name +
- "' with value of type '" + rhs_type_name + "'",
+ AddError("cannot initialize " + decl + " of type '" +
+ TypeNameOf(storage_ty) + "' with value of type '" +
+ TypeNameOf(rhs_ty) + "'",
var->source);
return false;
}
@@ -652,17 +681,18 @@
return false;
}
- auto* info = Variable(var, VariableKind::kGlobal);
- if (!info) {
+ auto* sem = Variable(var, VariableKind::kGlobal);
+ if (!sem) {
return false;
}
- variable_stack_.Set(var->symbol, info);
+ variable_stack_.Set(var->symbol, sem);
- if (!var->is_const && info->storage_class == ast::StorageClass::kNone) {
+ auto storage_class = sem->StorageClass();
+ if (!var->is_const && storage_class == ast::StorageClass::kNone) {
AddError("global variables must have a storage class", var->source);
return false;
}
- if (var->is_const && !(info->storage_class == ast::StorageClass::kNone)) {
+ if (var->is_const && storage_class != ast::StorageClass::kNone) {
AddError("global constants shouldn't have a storage class", var->source);
return false;
}
@@ -673,7 +703,7 @@
if (auto* override_deco = deco->As<ast::OverrideDecoration>()) {
// Track the constant IDs that are specified in the shader.
if (override_deco->has_value) {
- constant_ids_.emplace(override_deco->value, info);
+ constant_ids_.emplace(override_deco->value, sem);
}
}
}
@@ -682,26 +712,13 @@
return false;
}
- if (auto bp = var->BindingPoint()) {
- info->binding_point = {bp.group->value, bp.binding->value};
- }
-
- if (!ValidateGlobalVariable(info)) {
- return false;
- }
-
- if (!ApplyStorageClassUsageToType(
- info->storage_class, const_cast<sem::Type*>(info->type->UnwrapRef()),
- var->source)) {
- AddNote("while instantiating variable " +
- builder_->Symbols().NameFor(var->symbol),
- var->source);
+ if (!ValidateGlobalVariable(sem)) {
return false;
}
// TODO(bclayton): Call this at the end of resolve on all uniform and storage
// referenced structs
- if (!ValidateStorageClassLayout(info)) {
+ if (!ValidateStorageClassLayout(sem)) {
return false;
}
@@ -735,7 +752,7 @@
};
auto type_name_of = [this](const sem::StructMember* sm) {
- return sm->Declaration()->type->FriendlyName(builder_->Symbols());
+ return TypeNameOf(sm->Type());
};
// TODO(amaiorano): Output struct and member decorations so that this output
@@ -779,8 +796,7 @@
<< size << ") */ " << s << ";\n";
};
- print_struct_begin_line(st->Align(), st->Size(),
- st->FriendlyName(builder_->Symbols()));
+ print_struct_begin_line(st->Align(), st->Size(), TypeNameOf(st));
for (size_t i = 0; i < st->Members().size(); ++i) {
auto* const m = st->Members()[i];
@@ -911,10 +927,10 @@
return true;
}
-bool Resolver::ValidateStorageClassLayout(const VariableInfo* info) {
- if (auto* str = info->type->UnwrapRef()->As<sem::Struct>()) {
- if (!ValidateStorageClassLayout(str, info->storage_class)) {
- AddNote("see declaration of variable", info->declaration->source);
+bool Resolver::ValidateStorageClassLayout(const sem::Variable* var) {
+ if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) {
+ if (!ValidateStorageClassLayout(str, var->StorageClass())) {
+ AddNote("see declaration of variable", var->Declaration()->source);
return false;
}
}
@@ -922,24 +938,25 @@
return true;
}
-bool Resolver::ValidateGlobalVariable(const VariableInfo* info) {
- if (!ValidateNoDuplicateDecorations(info->declaration->decorations)) {
+bool Resolver::ValidateGlobalVariable(const sem::Variable* var) {
+ auto* decl = var->Declaration();
+ if (!ValidateNoDuplicateDecorations(decl->decorations)) {
return false;
}
- for (auto* deco : info->declaration->decorations) {
- if (info->declaration->is_const) {
+ for (auto* deco : decl->decorations) {
+ if (decl->is_const) {
if (auto* override_deco = deco->As<ast::OverrideDecoration>()) {
if (override_deco->has_value) {
uint32_t id = override_deco->value;
- auto itr = constant_ids_.find(id);
- if (itr != constant_ids_.end() && itr->second != info) {
+ auto it = constant_ids_.find(id);
+ if (it != constant_ids_.end() && it->second != var) {
AddError("pipeline constant IDs must be unique", deco->source);
AddNote("a pipeline constant with an ID of " + std::to_string(id) +
" was previously declared "
"here:",
ast::GetDecoration<ast::OverrideDecoration>(
- itr->second->declaration->decorations)
+ it->second->Declaration()->decorations)
->source);
return false;
}
@@ -958,8 +975,8 @@
deco->IsAnyOf<ast::BuiltinDecoration, ast::InterpolateDecoration,
ast::InvariantDecoration, ast::LocationDecoration>();
bool has_io_storage_class =
- info->storage_class == ast::StorageClass::kInput ||
- info->storage_class == ast::StorageClass::kOutput;
+ var->StorageClass() == ast::StorageClass::kInput ||
+ var->StorageClass() == ast::StorageClass::kOutput;
if (!(deco->IsAnyOf<ast::BindingDecoration, ast::GroupDecoration,
ast::InternalDecoration>()) &&
(!is_shader_io_decoration || !has_io_storage_class)) {
@@ -969,8 +986,8 @@
}
}
- auto binding_point = info->declaration->BindingPoint();
- switch (info->storage_class) {
+ auto binding_point = decl->BindingPoint();
+ switch (var->StorageClass()) {
case ast::StorageClass::kUniform:
case ast::StorageClass::kStorage:
case ast::StorageClass::kUniformConstant: {
@@ -981,7 +998,7 @@
AddError(
"resource variables require [[group]] and [[binding]] "
"decorations",
- info->declaration->source);
+ decl->source);
return false;
}
break;
@@ -993,7 +1010,7 @@
AddError(
"non-resource variables must not have [[group]] or [[binding]] "
"decorations",
- info->declaration->source);
+ decl->source);
return false;
}
}
@@ -1001,28 +1018,28 @@
// https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
// The access mode always has a default, and except for variables in the
// storage storage class, must not be written.
- if (info->storage_class != ast::StorageClass::kStorage &&
- info->declaration->declared_access != ast::Access::kUndefined) {
+ if (var->StorageClass() != ast::StorageClass::kStorage &&
+ decl->declared_access != ast::Access::kUndefined) {
AddError(
"only variables in <storage> storage class may declare an access mode",
- info->declaration->source);
+ decl->source);
return false;
}
- switch (info->storage_class) {
+ switch (var->StorageClass()) {
case ast::StorageClass::kStorage: {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
// A variable in the storage storage class is a storage buffer variable.
// Its store type must be a host-shareable structure type with block
// attribute, satisfying the storage class constraints.
- auto* str = info->type->UnwrapRef()->As<sem::Struct>();
+ auto* str = var->Type()->UnwrapRef()->As<sem::Struct>();
if (!str) {
AddError(
"variables declared in the <storage> storage class must be of a "
"structure type",
- info->declaration->source);
+ decl->source);
return false;
}
@@ -1031,9 +1048,8 @@
"structure used as a storage buffer must be declared with the "
"[[block]] decoration",
str->Declaration()->source);
- if (info->declaration->source.range.begin.line) {
- AddNote("structure used as storage buffer here",
- info->declaration->source);
+ if (decl->source.range.begin.line) {
+ AddNote("structure used as storage buffer here", decl->source);
}
return false;
}
@@ -1044,12 +1060,12 @@
// A variable in the uniform storage class is a uniform buffer variable.
// Its store type must be a host-shareable structure type with block
// attribute, satisfying the storage class constraints.
- auto* str = info->type->UnwrapRef()->As<sem::Struct>();
+ auto* str = var->Type()->UnwrapRef()->As<sem::Struct>();
if (!str) {
AddError(
"variables declared in the <uniform> storage class must be of a "
"structure type",
- info->declaration->source);
+ decl->source);
return false;
}
@@ -1058,9 +1074,8 @@
"structure used as a uniform buffer must be declared with the "
"[[block]] decoration",
str->Declaration()->source);
- if (info->declaration->source.range.begin.line) {
- AddNote("structure used as uniform buffer here",
- info->declaration->source);
+ if (decl->source.range.begin.line) {
+ AddNote("structure used as uniform buffer here", decl->source);
}
return false;
}
@@ -1071,7 +1086,7 @@
AddError(
"structure containing a runtime sized array "
"cannot be used as a uniform buffer",
- info->declaration->source);
+ decl->source);
AddNote("structure is declared here", str->Declaration()->source);
return false;
}
@@ -1084,24 +1099,24 @@
break;
}
- if (!info->declaration->is_const) {
- if (!ValidateAtomicVariable(info)) {
+ if (!decl->is_const) {
+ if (!ValidateAtomicVariable(var)) {
return false;
}
}
- return ValidateVariable(info);
+ return ValidateVariable(var);
}
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// Atomic types may only be instantiated by variables in the workgroup storage
// class or by storage buffer variables with a read_write access mode.
-bool Resolver::ValidateAtomicVariable(const VariableInfo* info) {
- auto sc = info->storage_class;
- auto access = info->access;
- auto* type = info->type->UnwrapRef();
- auto source = info->declaration->type ? info->declaration->type->source
- : info->declaration->source;
+bool Resolver::ValidateAtomicVariable(const sem::Variable* var) {
+ auto sc = var->StorageClass();
+ auto* decl = var->Declaration();
+ auto access = var->Access();
+ auto* type = var->Type()->UnwrapRef();
+ auto source = decl->type ? decl->type->source : decl->source;
if (type->Is<sem::Atomic>()) {
if (sc != ast::StorageClass::kWorkgroup) {
@@ -1118,10 +1133,9 @@
AddError(
"atomic variables must have <storage> or <workgroup> storage class",
source);
- AddNote("atomic sub-type of '" +
- type->FriendlyName(builder_->Symbols()) +
- "' is declared here",
- found->second);
+ AddNote(
+ "atomic sub-type of '" + TypeNameOf(type) + "' is declared here",
+ found->second);
return false;
} else if (sc == ast::StorageClass::kStorage &&
access != ast::Access::kReadWrite) {
@@ -1129,10 +1143,9 @@
"atomic variables in <storage> storage class must have read_write "
"access mode",
source);
- AddNote("atomic sub-type of '" +
- type->FriendlyName(builder_->Symbols()) +
- "' is declared here",
- found->second);
+ AddNote(
+ "atomic sub-type of '" + TypeNameOf(type) + "' is declared here",
+ found->second);
return false;
}
}
@@ -1141,75 +1154,85 @@
return true;
}
-bool Resolver::ValidateVariable(const VariableInfo* info) {
- auto* var = info->declaration;
- auto* storage_type = info->type->UnwrapRef();
+bool Resolver::ValidateVariable(const sem::Variable* var) {
+ auto* decl = var->Declaration();
+ auto* storage_ty = var->Type()->UnwrapRef();
- if (!var->is_const && !IsStorable(storage_type)) {
- AddError(storage_type->FriendlyName(builder_->Symbols()) +
- " cannot be used as the type of a var",
- var->source);
+ if (!decl->is_const && !IsStorable(storage_ty)) {
+ AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a var",
+ decl->source);
return false;
}
- if (var->is_const && info->kind != VariableKind::kParameter &&
- !(storage_type->IsConstructible() || storage_type->Is<sem::Pointer>())) {
- AddError(storage_type->FriendlyName(builder_->Symbols()) +
- " cannot be used as the type of a let",
- var->source);
+ if (decl->is_const && !var->Is<sem::Parameter>() &&
+ !(storage_ty->IsConstructible() || storage_ty->Is<sem::Pointer>())) {
+ AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a let",
+ decl->source);
return false;
}
- if (auto* r = storage_type->As<sem::Array>()) {
+ if (auto* r = storage_ty->As<sem::Array>()) {
if (r->IsRuntimeSized()) {
AddError("runtime arrays may only appear as the last member of a struct",
- var->source);
+ decl->source);
return false;
}
}
- if (auto* r = storage_type->As<sem::MultisampledTexture>()) {
+ if (auto* r = storage_ty->As<sem::MultisampledTexture>()) {
if (r->dim() != ast::TextureDimension::k2d) {
- AddError("only 2d multisampled textures are supported", var->source);
+ AddError("only 2d multisampled textures are supported", decl->source);
return false;
}
if (!r->type()->UnwrapRef()->is_numeric_scalar()) {
AddError("texture_multisampled_2d<type>: type must be f32, i32 or u32",
- var->source);
+ decl->source);
return false;
}
}
- if (storage_type->is_handle() &&
- var->declared_storage_class != ast::StorageClass::kNone) {
+ if (var->Is<sem::LocalVariable>() && !decl->is_const &&
+ IsValidationEnabled(decl->decorations,
+ ast::DisabledValidation::kIgnoreStorageClass)) {
+ if (!var->Type()->UnwrapRef()->IsConstructible()) {
+ AddError("function variable must have a constructible type",
+ decl->type ? decl->type->source : decl->source);
+ return false;
+ }
+ }
+
+ if (storage_ty->is_handle() &&
+ decl->declared_storage_class != ast::StorageClass::kNone) {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
// If the store type is a texture type or a sampler type, then the
// variable declaration must not have a storage class decoration. The
// storage class will always be handle.
- AddError("variables of type '" + info->type_name +
+ AddError("variables of type '" + TypeNameOf(storage_ty) +
"' must not have a storage class",
- var->source);
+ decl->source);
return false;
}
- if (IsValidationEnabled(var->decorations,
+ if (IsValidationEnabled(decl->decorations,
ast::DisabledValidation::kIgnoreStorageClass) &&
- (var->declared_storage_class == ast::StorageClass::kInput ||
- var->declared_storage_class == ast::StorageClass::kOutput)) {
- AddError("invalid use of input/output storage class", var->source);
+ (decl->declared_storage_class == ast::StorageClass::kInput ||
+ decl->declared_storage_class == ast::StorageClass::kOutput)) {
+ AddError("invalid use of input/output storage class", decl->source);
return false;
}
return true;
}
bool Resolver::ValidateFunctionParameter(const ast::Function* func,
- const VariableInfo* info) {
- if (!ValidateVariable(info)) {
+ const sem::Variable* var) {
+ if (!ValidateVariable(var)) {
return false;
}
- for (auto* deco : info->declaration->decorations) {
+ auto* decl = var->Declaration();
+
+ for (auto* deco : decl->decorations) {
if (!func->IsEntryPoint() && !deco->Is<ast::InternalDecoration>()) {
AddError(
"decoration is not valid for non-entry point function parameters",
@@ -1220,10 +1243,10 @@
ast::InterpolateDecoration,
ast::InternalDecoration>() &&
(IsValidationEnabled(
- info->declaration->decorations,
+ decl->decorations,
ast::DisabledValidation::kEntryPointParameter) &&
IsValidationEnabled(
- info->declaration->decorations,
+ decl->decorations,
ast::DisabledValidation::
kIgnoreConstructibleFunctionParameter))) {
AddError("decoration is not valid for function parameters", deco->source);
@@ -1231,34 +1254,35 @@
}
}
- if (auto* ref = info->type->As<sem::Pointer>()) {
+ if (auto* ref = var->Type()->As<sem::Pointer>()) {
auto sc = ref->StorageClass();
if (!(sc == ast::StorageClass::kFunction ||
sc == ast::StorageClass::kPrivate ||
sc == ast::StorageClass::kWorkgroup) &&
- IsValidationEnabled(info->declaration->decorations,
+ IsValidationEnabled(decl->decorations,
ast::DisabledValidation::kIgnoreStorageClass)) {
std::stringstream ss;
ss << "function parameter of pointer type cannot be in '" << sc
<< "' storage class";
- AddError(ss.str(), info->declaration->source);
+ AddError(ss.str(), decl->source);
return false;
}
}
- if (IsPlain(info->type)) {
- if (!info->type->IsConstructible() &&
+ if (IsPlain(var->Type())) {
+ if (!var->Type()->IsConstructible() &&
IsValidationEnabled(
- info->declaration->decorations,
+ decl->decorations,
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter)) {
AddError("store type of function parameter must be a constructible type",
- info->declaration->source);
+ decl->source);
return false;
}
- } else if (!info->type->IsAnyOf<sem::Texture, sem::Sampler, sem::Pointer>()) {
- AddError("store type of function parameter cannot be " +
- info->type->FriendlyName(builder_->Symbols()),
- info->declaration->source);
+ } else if (!var->Type()
+ ->IsAnyOf<sem::Texture, sem::Sampler, sem::Pointer>()) {
+ AddError(
+ "store type of function parameter cannot be " + TypeNameOf(var->Type()),
+ decl->source);
return false;
}
@@ -1266,11 +1290,11 @@
}
bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
- const sem::Type* storage_type,
+ const sem::Type* storage_ty,
const bool is_input) {
- auto* type = storage_type->UnwrapRef();
+ auto* type = storage_ty->UnwrapRef();
const auto stage = current_function_
- ? current_function_->declaration->PipelineStage()
+ ? current_function_->Declaration()->PipelineStage()
: ast::PipelineStage::kNone;
std::stringstream stage_name;
stage_name << stage;
@@ -1388,8 +1412,8 @@
bool Resolver::ValidateInterpolateDecoration(
const ast::InterpolateDecoration* deco,
- const sem::Type* storage_type) {
- auto* type = storage_type->UnwrapRef();
+ const sem::Type* storage_ty) {
+ auto* type = storage_ty->UnwrapRef();
if (type->is_integer_scalar_or_vector() &&
deco->type != ast::InterpolationType::kFlat) {
@@ -1409,18 +1433,18 @@
return true;
}
-bool Resolver::ValidateFunction(const ast::Function* func,
- const FunctionInfo* info) {
- if (!ValidateNoDuplicateDefinition(func->symbol, func->source,
+bool Resolver::ValidateFunction(const sem::Function* func) {
+ auto* decl = func->Declaration();
+ if (!ValidateNoDuplicateDefinition(decl->symbol, decl->source,
/* check_global_scope_only */ true)) {
return false;
}
auto workgroup_deco_count = 0;
- for (auto* deco : func->decorations) {
+ for (auto* deco : decl->decorations) {
if (deco->Is<ast::WorkgroupDecoration>()) {
workgroup_deco_count++;
- if (func->PipelineStage() != ast::PipelineStage::kCompute) {
+ if (decl->PipelineStage() != ast::PipelineStage::kCompute) {
AddError(
"the workgroup_size attribute is only valid for compute stages",
deco->source);
@@ -1433,41 +1457,41 @@
}
}
- if (func->params.size() > 255) {
- AddError("functions may declare at most 255 parameters", func->source);
+ if (decl->params.size() > 255) {
+ AddError("functions may declare at most 255 parameters", decl->source);
return false;
}
- for (auto* param : func->params) {
- if (!ValidateFunctionParameter(func, variable_to_info_.at(param))) {
+ for (size_t i = 0; i < decl->params.size(); i++) {
+ if (!ValidateFunctionParameter(decl, func->Parameters()[i])) {
return false;
}
}
- if (!info->return_type->Is<sem::Void>()) {
- if (!info->return_type->IsConstructible()) {
+ if (!func->ReturnType()->Is<sem::Void>()) {
+ if (!func->ReturnType()->IsConstructible()) {
AddError("function return type must be a constructible type",
- func->return_type->source);
+ decl->return_type->source);
return false;
}
- if (func->body) {
- if (!func->body->Last() ||
- !func->body->Last()->Is<ast::ReturnStatement>()) {
+ if (decl->body) {
+ if (!decl->body->Last() ||
+ !decl->body->Last()->Is<ast::ReturnStatement>()) {
AddError("non-void function must end with a return statement",
- func->source);
+ decl->source);
return false;
}
} else if (IsValidationEnabled(
- func->decorations,
+ decl->decorations,
ast::DisabledValidation::kFunctionHasNoBody)) {
TINT_ICE(Resolver, diagnostics_)
- << "Function " << builder_->Symbols().NameFor(func->symbol)
+ << "Function " << builder_->Symbols().NameFor(decl->symbol)
<< " has no body";
}
- for (auto* deco : func->return_type_decorations) {
- if (!func->IsEntryPoint()) {
+ for (auto* deco : decl->return_type_decorations) {
+ if (!decl->IsEntryPoint()) {
AddError(
"decoration is not valid for non-entry point function return types",
deco->source);
@@ -1476,9 +1500,9 @@
if (!deco->IsAnyOf<ast::BuiltinDecoration, ast::InternalDecoration,
ast::LocationDecoration, ast::InterpolateDecoration,
ast::InvariantDecoration>() &&
- (IsValidationEnabled(info->declaration->decorations,
+ (IsValidationEnabled(decl->decorations,
ast::DisabledValidation::kEntryPointParameter) &&
- IsValidationEnabled(info->declaration->decorations,
+ IsValidationEnabled(decl->decorations,
ast::DisabledValidation::
kIgnoreConstructibleFunctionParameter))) {
AddError("decoration is not valid for entry point return types",
@@ -1488,8 +1512,8 @@
}
}
- if (func->IsEntryPoint()) {
- if (!ValidateEntryPoint(func, info)) {
+ if (decl->IsEntryPoint()) {
+ if (!ValidateEntryPoint(func)) {
return false;
}
}
@@ -1497,12 +1521,13 @@
return true;
}
-bool Resolver::ValidateEntryPoint(const ast::Function* func,
- const FunctionInfo* info) {
+bool Resolver::ValidateEntryPoint(const sem::Function* func) {
+ auto* decl = func->Declaration();
+
// Use a lambda to validate the entry point decorations for a type.
// Persistent state is used to track which builtins and locations have
// already been seen, in order to catch conflicts.
- // TODO(jrprice): This state could be stored in FunctionInfo instead, and
+ // TODO(jrprice): This state could be stored in sem::Function instead, and
// then passed to sem::Function since it would be useful there too.
std::unordered_set<ast::Builtin> builtins;
std::unordered_set<uint32_t> locations;
@@ -1514,7 +1539,7 @@
// Inner lambda that is applied to a type and all of its members.
auto validate_entry_point_decorations_inner = [&](const ast::DecorationList&
decos,
- sem::Type* ty,
+ const sem::Type* ty,
Source source,
ParamOrRetType param_or_ret,
bool is_struct_member) {
@@ -1539,7 +1564,7 @@
" attribute appears multiple times as pipeline " +
(param_or_ret == ParamOrRetType::kParameter ? "input"
: "output"),
- func->source);
+ decl->source);
return false;
}
@@ -1564,14 +1589,14 @@
return false;
}
} else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
- if (func->PipelineStage() == ast::PipelineStage::kCompute) {
+ if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
is_invalid_compute_shader_decoration = true;
} else if (!ValidateInterpolateDecoration(interpolate, ty)) {
return false;
}
interpolate_attribute = interpolate;
} else if (auto* invariant = deco->As<ast::InvariantDecoration>()) {
- if (func->PipelineStage() == ast::PipelineStage::kCompute) {
+ if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
is_invalid_compute_shader_decoration = true;
}
invariant_attribute = invariant;
@@ -1609,14 +1634,14 @@
if (ty->is_integer_scalar_or_vector() && !interpolate_attribute) {
// TODO(crbug.com/tint/1224): Make these errors once downstream
// usages have caught up (no sooner than M99).
- if (func->PipelineStage() == ast::PipelineStage::kVertex &&
+ if (decl->PipelineStage() == ast::PipelineStage::kVertex &&
param_or_ret == ParamOrRetType::kReturnType) {
AddWarning(
"integral user-defined vertex outputs must have a flat "
"interpolation attribute",
source);
}
- if (func->PipelineStage() == ast::PipelineStage::kFragment &&
+ if (decl->PipelineStage() == ast::PipelineStage::kFragment &&
param_or_ret == ParamOrRetType::kParameter) {
AddWarning(
"integral user-defined fragment inputs must have a flat "
@@ -1648,7 +1673,8 @@
// Outer lambda for validating the entry point decorations for a type.
auto validate_entry_point_decorations = [&](const ast::DecorationList& decos,
- sem::Type* ty, Source source,
+ const sem::Type* ty,
+ Source source,
ParamOrRetType param_or_ret) {
if (!validate_entry_point_decorations_inner(decos, ty, source, param_or_ret,
/*is_struct_member*/ false)) {
@@ -1662,8 +1688,8 @@
member->Declaration()->source, param_or_ret,
/*is_struct_member*/ true)) {
AddNote("while analysing entry point '" +
- builder_->Symbols().NameFor(func->symbol) + "'",
- func->source);
+ builder_->Symbols().NameFor(decl->symbol) + "'",
+ decl->source);
return false;
}
}
@@ -1672,10 +1698,11 @@
return true;
};
- for (auto* param : info->parameters) {
- if (!validate_entry_point_decorations(
- param->declaration->decorations, param->type,
- param->declaration->source, ParamOrRetType::kParameter)) {
+ for (auto* param : func->Parameters()) {
+ auto* param_decl = param->Declaration();
+ if (!validate_entry_point_decorations(param_decl->decorations,
+ param->Type(), param_decl->source,
+ ParamOrRetType::kParameter)) {
return false;
}
}
@@ -1686,21 +1713,21 @@
builtins.clear();
locations.clear();
- if (!info->return_type->Is<sem::Void>()) {
- if (!validate_entry_point_decorations(func->return_type_decorations,
- info->return_type, func->source,
+ if (!func->ReturnType()->Is<sem::Void>()) {
+ if (!validate_entry_point_decorations(decl->return_type_decorations,
+ func->ReturnType(), decl->source,
ParamOrRetType::kReturnType)) {
return false;
}
}
- if (func->PipelineStage() == ast::PipelineStage::kVertex &&
+ if (decl->PipelineStage() == ast::PipelineStage::kVertex &&
builtins.count(ast::Builtin::kPosition) == 0) {
// Check module-scope variables, as the SPIR-V sanitizer generates these.
bool found = false;
- for (auto* var : info->referenced_module_vars) {
+ for (auto* global : func->TransitivelyReferencedGlobals()) {
if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
- var->declaration->decorations)) {
+ global->Declaration()->decorations)) {
if (builtin->builtin == ast::Builtin::kPosition) {
found = true;
break;
@@ -1711,31 +1738,32 @@
AddError(
"a vertex shader must include the 'position' builtin in its return "
"type",
- func->source);
+ decl->source);
return false;
}
}
- if (func->PipelineStage() == ast::PipelineStage::kCompute) {
- if (!ast::HasDecoration<ast::WorkgroupDecoration>(func->decorations)) {
+ if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
+ if (!ast::HasDecoration<ast::WorkgroupDecoration>(decl->decorations)) {
AddError(
"a compute shader must include 'workgroup_size' in its "
"attributes",
- func->source);
+ decl->source);
return false;
}
}
// Validate there are no resource variable binding collisions
std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points;
- for (auto* var_info : info->referenced_module_vars) {
- if (!var_info->declaration->BindingPoint()) {
+ for (auto* var : func->TransitivelyReferencedGlobals()) {
+ auto* var_decl = var->Declaration();
+ if (!var_decl->BindingPoint()) {
continue;
}
- auto bp = var_info->binding_point;
- auto res = binding_points.emplace(bp, var_info->declaration);
+ auto bp = var->BindingPoint();
+ auto res = binding_points.emplace(bp, var_decl);
if (!res.second &&
- IsValidationEnabled(var_info->declaration->decorations,
+ IsValidationEnabled(var_decl->decorations,
ast::DisabledValidation::kBindingPointCollision) &&
IsValidationEnabled(res.first->second->decorations,
ast::DisabledValidation::kBindingPointCollision)) {
@@ -1744,13 +1772,13 @@
// variables in the resource interface of a given shader must not have
// the same group and binding values, when considered as a pair of
// values.
- auto func_name = builder_->Symbols().NameFor(info->declaration->symbol);
+ auto func_name = builder_->Symbols().NameFor(decl->symbol);
AddError("entry point '" + func_name +
"' references multiple variables that use the "
"same resource binding [[group(" +
std::to_string(bp.group) + "), binding(" +
std::to_string(bp.binding) + ")]]",
- var_info->declaration->source);
+ var_decl->source);
AddNote("first resource binding usage declared here",
res.first->second->source);
return false;
@@ -1760,19 +1788,16 @@
return true;
}
-bool Resolver::Function(const ast::Function* func) {
- auto* info = function_infos_.Create<FunctionInfo>(func);
-
- if (func->IsEntryPoint()) {
- entry_points_.emplace_back(info);
- }
-
- TINT_SCOPED_ASSIGNMENT(current_function_, info);
-
+sem::Function* Resolver::Function(const ast::Function* decl) {
variable_stack_.Push();
+ TINT_DEFER(variable_stack_.Pop());
+
uint32_t parameter_index = 0;
std::unordered_map<Symbol, Source> parameter_names;
- for (auto* param : func->params) {
+ std::vector<sem::Parameter*> parameters;
+
+ // Resolve all the parameters
+ for (auto* param : decl->params) {
Mark(param);
{ // Check the parameter name is unique for the function
@@ -1781,48 +1806,29 @@
auto name = builder_->Symbols().NameFor(param->symbol);
AddError("redefinition of parameter '" + name + "'", param->source);
AddNote("previous definition is here", emplaced.first->second);
- return false;
+ return nullptr;
}
}
- auto* param_info =
- Variable(param, VariableKind::kParameter, parameter_index++);
- if (!param_info) {
- return false;
+ auto* var = As<sem::Parameter>(
+ Variable(param, VariableKind::kParameter, parameter_index++));
+ if (!var) {
+ return nullptr;
}
for (auto* deco : param->decorations) {
Mark(deco);
}
if (!ValidateNoDuplicateDecorations(param->decorations)) {
- return false;
+ return nullptr;
}
- variable_stack_.Set(param->symbol, param_info);
- info->parameters.emplace_back(param_info);
+ variable_stack_.Set(param->symbol, var);
+ parameters.emplace_back(var);
- if (!ApplyStorageClassUsageToType(param->declared_storage_class,
- param_info->type, param->source)) {
- AddNote("while instantiating parameter " +
- builder_->Symbols().NameFor(param->symbol),
- param->source);
- return false;
- }
- if (auto* ptr = param_info->type->As<sem::Pointer>()) {
- // For MSL, we push module-scope variables into the entry point as pointer
- // parameters, so we also need to handle their store type.
- if (!ApplyStorageClassUsageToType(
- ptr->StorageClass(), const_cast<sem::Type*>(ptr->StoreType()),
- param->source)) {
- AddNote("while instantiating parameter " +
- builder_->Symbols().NameFor(param->symbol),
- param->source);
- return false;
- }
- }
-
- if (auto* str = param_info->type->As<sem::Struct>()) {
- switch (func->PipelineStage()) {
+ auto* var_ty = const_cast<sem::Type*>(var->Type());
+ if (auto* str = var_ty->As<sem::Struct>()) {
+ switch (decl->PipelineStage()) {
case ast::PipelineStage::kVertex:
str->AddUsage(sem::PipelineStageUsage::kVertexInput);
break;
@@ -1838,28 +1844,27 @@
}
}
- if (auto* ty = func->return_type) {
- info->return_type = Type(ty);
- info->return_type_name = ty->FriendlyName(builder_->Symbols());
- if (!info->return_type) {
- return false;
+ // Resolve the return type
+ sem::Type* return_type = nullptr;
+ if (auto* ty = decl->return_type) {
+ return_type = Type(ty);
+ if (!return_type) {
+ return nullptr;
}
} else {
- info->return_type = builder_->create<sem::Void>();
- info->return_type_name =
- info->return_type->FriendlyName(builder_->Symbols());
+ return_type = builder_->create<sem::Void>();
}
- if (auto* str = info->return_type->As<sem::Struct>()) {
+ if (auto* str = return_type->As<sem::Struct>()) {
if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
- func->source)) {
+ decl->source)) {
AddNote("while instantiating return type for " +
- builder_->Symbols().NameFor(func->symbol),
- func->source);
- return false;
+ builder_->Symbols().NameFor(decl->symbol),
+ decl->source);
+ return nullptr;
}
- switch (func->PipelineStage()) {
+ switch (decl->PipelineStage()) {
case ast::PipelineStage::kVertex:
str->AddUsage(sem::PipelineStageUsage::kVertexOutput);
break;
@@ -1874,139 +1879,165 @@
}
}
- if (func->body) {
- Mark(func->body);
+ sem::WorkgroupSize ws{};
+ if (!WorkgroupSizeFor(decl, ws)) {
+ return nullptr;
+ }
+
+ auto* func =
+ builder_->create<sem::Function>(decl, return_type, parameters, ws);
+ builder_->Sem().Add(decl, func);
+
+ if (decl->IsEntryPoint()) {
+ entry_points_.emplace_back(func);
+ }
+
+ TINT_SCOPED_ASSIGNMENT(current_function_, func);
+
+ if (decl->body) {
+ Mark(decl->body);
if (current_compound_statement_) {
TINT_ICE(Resolver, diagnostics_)
<< "Resolver::Function() called with a current compound statement";
- return false;
+ return nullptr;
}
auto* sem_block = builder_->create<sem::FunctionBlockStatement>(func);
- builder_->Sem().Add(func->body, sem_block);
- if (!Scope(sem_block, [&] { return Statements(func->body->statements); })) {
- return false;
- }
- }
- variable_stack_.Pop();
-
- for (auto* deco : func->decorations) {
- Mark(deco);
- }
- if (!ValidateNoDuplicateDecorations(func->decorations)) {
- return false;
- }
-
- for (auto* deco : func->return_type_decorations) {
- Mark(deco);
- }
- if (!ValidateNoDuplicateDecorations(func->return_type_decorations)) {
- return false;
- }
-
- // Set work-group size defaults.
- for (int i = 0; i < 3; i++) {
- info->workgroup_size[i].value = 1;
- info->workgroup_size[i].overridable_const = nullptr;
- }
-
- if (auto* workgroup =
- ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations)) {
- auto values = workgroup->Values();
- auto any_i32 = false;
- auto any_u32 = false;
- for (int i = 0; i < 3; i++) {
- // Each argument to this decoration can either be a literal, an
- // identifier for a module-scope constants, or nullptr if not specified.
-
- auto* expr = values[i];
- if (!expr) {
- // Not specified, just use the default.
- continue;
- }
-
- if (!Expression(expr)) {
- return false;
- }
-
- constexpr const char* kErrBadType =
- "workgroup_size argument must be either literal or module-scope "
- "constant of type i32 or u32";
- constexpr const char* kErrInconsistentType =
- "workgroup_size arguments must be of the same type, either i32 "
- "or u32";
-
- auto* ty = TypeOf(expr);
- bool is_i32 = ty->UnwrapRef()->Is<sem::I32>();
- bool is_u32 = ty->UnwrapRef()->Is<sem::U32>();
- if (!is_i32 && !is_u32) {
- AddError(kErrBadType, expr->source);
- return false;
- }
-
- any_i32 = any_i32 || is_i32;
- any_u32 = any_u32 || is_u32;
- if (any_i32 && any_u32) {
- AddError(kErrInconsistentType, expr->source);
- return false;
- }
-
- if (auto* ident = expr->As<ast::IdentifierExpression>()) {
- // We have an identifier of a module-scope constant.
- VariableInfo* var = variable_stack_.Get(ident->symbol);
- if (!var || !(var->declaration->is_const)) {
- AddError(kErrBadType, expr->source);
- return false;
- }
-
- // Capture the constant if an [[override]] attribute is present.
- if (ast::HasDecoration<ast::OverrideDecoration>(
- var->declaration->decorations)) {
- info->workgroup_size[i].overridable_const = var->declaration;
- }
-
- expr = var->declaration->constructor;
- if (!expr) {
- // No constructor means this value must be overriden by the user.
- info->workgroup_size[i].value = 0;
- continue;
- }
- } else if (!expr->Is<ast::ScalarConstructorExpression>()) {
- AddError(
- "workgroup_size argument must be either a literal or a "
- "module-scope constant",
- values[i]->source);
- return false;
- }
-
- auto val = ConstantValueOf(expr);
- if (!val) {
- TINT_ICE(Resolver, diagnostics_)
- << "could not resolve constant workgroup_size constant value";
- continue;
- }
- // Validate and set the default value for this dimension.
- if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) {
- AddError("workgroup_size argument must be at least 1",
- values[i]->source);
- return false;
- }
-
- info->workgroup_size[i].value =
- is_i32 ? static_cast<uint32_t>(val.Elements()[0].i32)
- : val.Elements()[0].u32;
+ builder_->Sem().Add(decl->body, sem_block);
+ if (!Scope(sem_block, [&] { return Statements(decl->body->statements); })) {
+ return nullptr;
}
}
- if (!ValidateFunction(func, info)) {
- return false;
+ for (auto* deco : decl->decorations) {
+ Mark(deco);
+ }
+ if (!ValidateNoDuplicateDecorations(decl->decorations)) {
+ return nullptr;
+ }
+
+ for (auto* deco : decl->return_type_decorations) {
+ Mark(deco);
+ }
+ if (!ValidateNoDuplicateDecorations(decl->return_type_decorations)) {
+ return nullptr;
+ }
+
+ if (!ValidateFunction(func)) {
+ return nullptr;
}
// Register the function information _after_ processing the statements. This
// allows us to catch a function calling itself when determining the call
// information as this function doesn't exist until it's finished.
- symbol_to_function_[func->symbol] = info;
- function_to_info_.emplace(func, info);
+ symbol_to_function_[decl->symbol] = func;
+ // If this is an entry point, mark all transitively called functions as being
+ // used by this entry point.
+ if (decl->IsEntryPoint()) {
+ for (auto* f : func->TransitivelyCalledFunctions()) {
+ const_cast<sem::Function*>(f)->AddAncestorEntryPoint(func);
+ }
+ }
+
+ return func;
+}
+
+bool Resolver::WorkgroupSizeFor(const ast::Function* func,
+ sem::WorkgroupSize& ws) {
+ // Computes ws from func's workgroup_size decoration, returning false on error. Each dimension defaults to 1 with no overridable constant.
+ for (int i = 0; i < 3; i++) {
+ ws[i].value = 1;
+ ws[i].overridable_const = nullptr;
+ }
+
+ auto* deco = ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations);
+ if (!deco) {
+ return true;
+ }
+
+ auto values = deco->Values();
+ auto any_i32 = false;
+ auto any_u32 = false;
+ for (int i = 0; i < 3; i++) {
+ // Each argument to this decoration can either be a literal, an
+ // identifier for a module-scope constant, or nullptr if not specified.
+
+ auto* expr = values[i];
+ if (!expr) {
+ // Not specified, just use the default.
+ continue;
+ }
+ // NOTE(review): Function() calls this before assigning current_function_, so Expression() below runs without this function as the current function - confirm all callees tolerate that.
+ auto* expr_sem = Expression(expr);
+ if (!expr_sem) {
+ return false;
+ }
+
+ constexpr const char* kErrBadType =
+ "workgroup_size argument must be either literal or module-scope "
+ "constant of type i32 or u32";
+ constexpr const char* kErrInconsistentType =
+ "workgroup_size arguments must be of the same type, either i32 "
+ "or u32";
+
+ auto* ty = TypeOf(expr);
+ bool is_i32 = ty->UnwrapRef()->Is<sem::I32>();
+ bool is_u32 = ty->UnwrapRef()->Is<sem::U32>();
+ if (!is_i32 && !is_u32) {
+ AddError(kErrBadType, expr->source);
+ return false;
+ }
+
+ any_i32 = any_i32 || is_i32;
+ any_u32 = any_u32 || is_u32;
+ if (any_i32 && any_u32) {
+ AddError(kErrInconsistentType, expr->source);
+ return false;
+ }
+
+ if (auto* ident = expr->As<ast::IdentifierExpression>()) {
+ // We have an identifier of a module-scope constant. NOTE(review): Function() pushes its parameters onto variable_stack_ before calling this, so a same-named parameter could shadow the constant here - confirm.
+ auto* var = variable_stack_.Get(ident->symbol);
+ if (!var || !var->Declaration()->is_const) {
+ AddError(kErrBadType, expr->source);
+ return false;
+ }
+
+ auto* decl = var->Declaration();
+ // Capture the constant if an [[override]] attribute is present.
+ if (ast::HasDecoration<ast::OverrideDecoration>(decl->decorations)) {
+ ws[i].overridable_const = decl;
+ }
+
+ expr = decl->constructor;
+ if (!expr) {
+ // No constructor means this value must be overridden by the user.
+ ws[i].value = 0;
+ continue;
+ }
+ } else if (!expr->Is<ast::ScalarConstructorExpression>()) {
+ AddError(
+ "workgroup_size argument must be either a literal or a "
+ "module-scope constant",
+ values[i]->source);
+ return false;
+ }
+ // The value is read from expr_sem (the original argument), not expr: identifiers of constants carry the propagated constant value.
+ auto val = expr_sem->ConstantValue();
+ if (!val) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "could not resolve constant workgroup_size constant value";
+ continue;
+ }
+ // Validate and set the default value for this dimension.
+ if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) {
+ AddError("workgroup_size argument must be at least 1", values[i]->source);
+ return false;
+ }
+
+ ws[i].value = is_i32 ? static_cast<uint32_t>(val.Elements()[0].i32)
+ : val.Elements()[0].u32;
+ }
return true;
}
@@ -2080,8 +2111,8 @@
}
// Non-Compound statements
- sem::Statement* sem_statement =
- builder_->create<sem::Statement>(stmt, current_compound_statement_);
+ sem::Statement* sem_statement = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem_statement);
TINT_SCOPED_ASSIGNMENT(current_statement_, sem_statement);
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
@@ -2100,9 +2131,6 @@
if (!Expression(c->expr)) {
return false;
}
- if (!ValidateCallStatement(c)) {
- return false;
- }
return true;
}
if (auto* c = stmt->As<ast::ContinueStatement>()) {
@@ -2158,7 +2186,7 @@
bool Resolver::CaseStatement(const ast::CaseStatement* stmt) {
auto* sem = builder_->create<sem::SwitchCaseBlockStatement>(
- stmt->body, current_compound_statement_);
+ stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
builder_->Sem().Add(stmt->body, sem);
Mark(stmt->body);
@@ -2169,8 +2197,8 @@
}
bool Resolver::IfStatement(const ast::IfStatement* stmt) {
- auto* sem =
- builder_->create<sem::IfStatement>(stmt, current_compound_statement_);
+ auto* sem = builder_->create<sem::IfStatement>(
+ stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
if (!Expression(stmt->condition)) {
@@ -2179,15 +2207,15 @@
auto* cond_type = TypeOf(stmt->condition)->UnwrapRef();
if (!cond_type->Is<sem::Bool>()) {
- AddError("if statement condition must be bool, got " +
- cond_type->FriendlyName(builder_->Symbols()),
- stmt->condition->source);
+ AddError(
+ "if statement condition must be bool, got " + TypeNameOf(cond_type),
+ stmt->condition->source);
return false;
}
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
- stmt->body, current_compound_statement_);
+ stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body);
if (!Scope(body, [&] { return Statements(stmt->body->statements); })) {
return false;
@@ -2204,8 +2232,8 @@
}
bool Resolver::ElseStatement(const ast::ElseStatement* stmt) {
- auto* sem =
- builder_->create<sem::ElseStatement>(stmt, current_compound_statement_);
+ auto* sem = builder_->create<sem::ElseStatement>(
+ stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
if (auto* cond = stmt->condition) {
@@ -2216,7 +2244,7 @@
auto* else_cond_type = TypeOf(cond)->UnwrapRef();
if (!else_cond_type->Is<sem::Bool>()) {
AddError("else statement condition must be bool, got " +
- else_cond_type->FriendlyName(builder_->Symbols()),
+ TypeNameOf(else_cond_type),
cond->source);
return false;
}
@@ -2224,7 +2252,7 @@
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
- stmt->body, current_compound_statement_);
+ stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body);
return Scope(body, [&] { return Statements(stmt->body->statements); });
});
@@ -2232,20 +2260,21 @@
bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
auto* sem = builder_->create<sem::BlockStatement>(
- stmt->As<ast::BlockStatement>(), current_compound_statement_);
+ stmt, current_compound_statement_,
+ current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] { return Statements(stmt->statements); });
}
bool Resolver::LoopStatement(const ast::LoopStatement* stmt) {
- auto* sem =
- builder_->create<sem::LoopStatement>(stmt, current_compound_statement_);
+ auto* sem = builder_->create<sem::LoopStatement>(
+ stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
Mark(stmt->body);
auto* body = builder_->create<sem::LoopBlockStatement>(
- stmt->body, current_compound_statement_);
+ stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body);
return Scope(body, [&] {
if (!Statements(stmt->body->statements)) {
@@ -2256,7 +2285,8 @@
if (!stmt->continuing->Empty()) {
auto* continuing =
builder_->create<sem::LoopContinuingBlockStatement>(
- stmt->continuing, current_compound_statement_);
+ stmt->continuing, current_compound_statement_,
+ current_function_);
builder_->Sem().Add(stmt->continuing, continuing);
if (!Scope(continuing, [&] {
return Statements(stmt->continuing->statements);
@@ -2272,7 +2302,7 @@
bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
auto* sem = builder_->create<sem::ForLoopStatement>(
- stmt, current_compound_statement_);
+ stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
if (auto* initializer = stmt->initializer) {
@@ -2287,10 +2317,10 @@
return false;
}
- if (!TypeOf(condition)->UnwrapRef()->Is<sem::Bool>()) {
- AddError(
- "for-loop condition must be bool, got " + TypeNameOf(condition),
- condition->source);
+ auto* cond_ty = TypeOf(condition)->UnwrapRef();
+ if (!cond_ty->Is<sem::Bool>()) {
+ AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
+ condition->source);
return false;
}
}
@@ -2305,13 +2335,13 @@
Mark(stmt->body);
auto* body = builder_->create<sem::LoopBlockStatement>(
- stmt->body, current_compound_statement_);
+ stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body);
return Scope(body, [&] { return Statements(stmt->body->statements); });
});
}
-bool Resolver::Expression(const ast::Expression* root) {
+sem::Expression* Resolver::Expression(const ast::Expression* root) {
std::vector<const ast::Expression*> sorted;
if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
root, diagnostics_, [&](const ast::Expression* expr) {
@@ -2319,145 +2349,241 @@
sorted.emplace_back(expr);
return ast::TraverseAction::Descend;
})) {
- return false;
+ return nullptr;
}
for (auto* expr : utils::Reverse(sorted)) {
- bool ok = false;
+ sem::Expression* sem_expr = nullptr;
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
- ok = ArrayAccessor(array);
+ sem_expr = ArrayAccessor(array);
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
- ok = Binary(bin_op);
+ sem_expr = Binary(bin_op);
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
- ok = Bitcast(bitcast);
+ sem_expr = Bitcast(bitcast);
} else if (auto* call = expr->As<ast::CallExpression>()) {
- ok = Call(call);
+ sem_expr = Call(call);
} else if (auto* ctor = expr->As<ast::ConstructorExpression>()) {
- ok = Constructor(ctor);
+ sem_expr = Constructor(ctor);
} else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
- ok = Identifier(ident);
+ sem_expr = Identifier(ident);
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
- ok = MemberAccessor(member);
+ sem_expr = MemberAccessor(member);
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
- ok = UnaryOp(unary);
+ sem_expr = UnaryOp(unary);
} else if (expr->Is<ast::PhonyExpression>()) {
- ok = true; // No-op
+ sem_expr = builder_->create<sem::Expression>(
+ expr, builder_->create<sem::Void>(), current_statement_,
+ sem::Constant{});
} else {
TINT_ICE(Resolver, diagnostics_)
<< "unhandled expression type: " << expr->TypeInfo().name;
- return false;
+ return nullptr;
}
- if (!ok) {
- return false;
+ if (!sem_expr) {
+ return nullptr;
+ }
+ builder_->Sem().Add(expr, sem_expr);
+ if (expr == root) {
+ return sem_expr;
}
}
- return true;
+ TINT_ICE(Resolver, diagnostics_) << "Expression() did not find root node";
+ return nullptr;
}
-bool Resolver::ArrayAccessor(const ast::ArrayAccessorExpression* expr) {
+sem::Expression* Resolver::ArrayAccessor(
+ const ast::ArrayAccessorExpression* expr) {
auto* idx = expr->index;
- auto* res = TypeOf(expr->array);
- auto* parent_type = res->UnwrapRef();
- const sem::Type* ret = nullptr;
- if (auto* arr = parent_type->As<sem::Array>()) {
- ret = arr->ElemType();
- } else if (auto* vec = parent_type->As<sem::Vector>()) {
- ret = vec->type();
- } else if (auto* mat = parent_type->As<sem::Matrix>()) {
- ret = builder_->create<sem::Vector>(mat->type(), mat->rows());
+ auto* parent_raw_ty = TypeOf(expr->array);  // May be a reference; see below.
+ auto* parent_ty = parent_raw_ty->UnwrapRef();
+ const sem::Type* ty = nullptr;
+ if (auto* arr = parent_ty->As<sem::Array>()) {
+ ty = arr->ElemType();
+ } else if (auto* vec = parent_ty->As<sem::Vector>()) {
+ ty = vec->type();
+ } else if (auto* mat = parent_ty->As<sem::Matrix>()) {
+ ty = builder_->create<sem::Vector>(mat->type(), mat->rows());  // One column.
} else {
- AddError("cannot index type '" + TypeNameOf(expr->array) + "'",
- expr->source);
- return false;
+ AddError("cannot index type '" + TypeNameOf(parent_ty) + "'", expr->source);
+ return nullptr;
}
- if (!TypeOf(idx)->UnwrapRef()->IsAnyOf<sem::I32, sem::U32>()) {
+ auto* idx_ty = TypeOf(idx)->UnwrapRef();
+ if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
AddError("index must be of type 'i32' or 'u32', found: '" +
- TypeNameOf(idx) + "'",
+ TypeNameOf(idx_ty) + "'",
idx->source);
- return false;
+ return nullptr;
}
- if (parent_type->Is<sem::Array>() || parent_type->Is<sem::Matrix>()) {
- if (!res->Is<sem::Reference>()) {
+ if (parent_ty->IsAnyOf<sem::Array, sem::Matrix>()) {
+ if (!parent_raw_ty->Is<sem::Reference>()) {  // Indexing a value, not memory.
// TODO(bclayton): expand this to allow any const_expr expression
// https://github.com/gpuweb/gpuweb/issues/1272
auto* scalar = idx->As<ast::ScalarConstructorExpression>();
if (!scalar || !scalar->literal->As<ast::IntLiteral>()) {
AddError("index must be signed or unsigned integer literal",
idx->source);
- return false;
+ return nullptr;
}
}
}
// If we're extracting from a reference, we return a reference.
- if (auto* ref = res->As<sem::Reference>()) {
- ret = builder_->create<sem::Reference>(ret, ref->StorageClass(),
- ref->Access());
+ if (auto* ref = parent_raw_ty->As<sem::Reference>()) {
+ ty = builder_->create<sem::Reference>(ty, ref->StorageClass(),
+ ref->Access());
}
- SetExprInfo(expr, ret);
- return true;
+ auto val = EvaluateConstantValue(expr, ty);
+ return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
}
-bool Resolver::Bitcast(const ast::BitcastExpression* expr) {
+sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
auto* ty = Type(expr->type);
if (!ty) {
- return false;
+ return nullptr;
}
if (ty->Is<sem::Pointer>()) {
AddError("cannot cast to a pointer", expr->source);
- return false;
+ return nullptr;
}
- SetExprInfo(expr, ty, expr->type->FriendlyName(builder_->Symbols()));
- return true;
+ // Build the semantic node, attaching whatever constant value is evaluated for the cast.
+ auto val = EvaluateConstantValue(expr, ty);
+ return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
}
-bool Resolver::Call(const ast::CallExpression* call) {
- Mark(call->func);
- auto* ident = call->func;
+sem::Expression* Resolver::Call(const ast::CallExpression* expr) {
+ auto* ident = expr->func;
+ Mark(ident);
auto name = builder_->Symbols().NameFor(ident->symbol);
auto intrinsic_type = sem::ParseIntrinsicType(name);
- if (intrinsic_type != IntrinsicType::kNone) {
- if (!IntrinsicCall(call, intrinsic_type)) {
- return false;
+ auto* call = (intrinsic_type != IntrinsicType::kNone)
+ ? IntrinsicCall(expr, intrinsic_type) : FunctionCall(expr);
+ if (call && current_function_) {  // call is null on error; current_function_ may be null (e.g. workgroup_size args).
+ current_function_->AddDirectCall(call);
+ }
+ return call;
+}
+
+sem::Call* Resolver::IntrinsicCall(const ast::CallExpression* expr,
+ sem::IntrinsicType intrinsic_type) {
+ std::vector<const sem::Expression*> args(expr->args.size());
+ std::vector<const sem::Type*> arg_tys(expr->args.size());
+ for (size_t i = 0; i < expr->args.size(); i++) {
+ auto* arg = Sem(expr->args[i]);
+ if (!arg) {
+ return nullptr;
}
- } else {
- if (!FunctionCall(call)) {
- return false;
+ args[i] = arg;
+ arg_tys[i] = arg->Type();
+ }
+
+ auto* intrinsic = intrinsic_table_->Lookup(intrinsic_type, std::move(arg_tys),
+ expr->source);
+ if (!intrinsic) {
+ return nullptr;
+ }
+
+ if (intrinsic->IsDeprecated()) {
+ AddWarning("use of deprecated intrinsic", expr->source);
+ }
+
+ auto* call = builder_->create<sem::Call>(expr, intrinsic, std::move(args),
+ current_statement_);
+
+ current_function_->AddDirectlyCalledIntrinsic(intrinsic);
+
+ if (IsTextureIntrinsic(intrinsic_type) &&
+ !ValidateTextureIntrinsicFunction(call)) {
+ return nullptr;
+ }
+
+ if (!ValidateCall(call)) {
+ return nullptr;
+ }
+
+ return call;
+}
+
+sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr) {
+ auto* ident = expr->func;
+ auto name = builder_->Symbols().NameFor(ident->symbol);
+
+ auto target_it = symbol_to_function_.find(ident->symbol);
+ if (target_it == symbol_to_function_.end()) {
+ if (current_function_ &&
+ current_function_->Declaration()->symbol == ident->symbol) {
+ AddError("recursion is not permitted. '" + name +
+ "' attempted to call itself.",
+ expr->source);
+ } else {
+ AddError("unable to find called function: " + name, expr->source);
+ }
+ return nullptr;
+ }
+ auto* target = target_it->second;
+
+ std::vector<const sem::Expression*> args(expr->args.size());
+ for (size_t i = 0; i < expr->args.size(); i++) {
+ auto* arg = Sem(expr->args[i]);
+ if (!arg) {
+ return nullptr;
+ }
+ args[i] = arg;
+ }
+
+ auto* call = builder_->create<sem::Call>(expr, target, std::move(args),
+ current_statement_);
+
+ if (current_function_) {
+ target->AddCallSite(call);
+
+ // Note: Requires called functions to be resolved first.
+ // This is currently guaranteed as functions must be declared before
+ // use.
+ current_function_->AddTransitivelyCalledFunction(target);
+ for (auto* transitive_call : target->TransitivelyCalledFunctions()) {
+ current_function_->AddTransitivelyCalledFunction(transitive_call);
+ }
+
+ // We inherit any referenced variables from the callee.
+ for (auto* var : target->TransitivelyReferencedGlobals()) {
+ current_function_->AddTransitivelyReferencedGlobal(var);
}
}
- return ValidateCall(call);
+ if (!ValidateFunctionCall(call)) {
+ return nullptr;
+ }
+
+ if (!ValidateCall(call)) {
+ return nullptr;
+ }
+
+ return call;
}
-bool Resolver::ValidateCall(const ast::CallExpression* call) {
- if (TypeOf(call)->Is<sem::Void>()) {
+bool Resolver::ValidateCall(const sem::Call* call) {
+ if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false;
- if (current_statement_) {
- if (auto* call_stmt =
- As<ast::CallStatement>(current_statement_->Declaration())) {
- if (call_stmt->expr == call) {
- is_call_statement = true;
- }
+ if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
+ if (call_stmt->expr == call->Declaration()) {
+ is_call_statement = true;
}
}
if (!is_call_statement) {
// https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
// If the called function does not return a value, a function call
// statement should be used instead.
- auto* ident = call->func;
+ auto* ident = call->Declaration()->func;
auto name = builder_->Symbols().NameFor(ident->symbol);
- // A function call is made to either a user declared function or an
- // intrinsic. function_calls_ only maps CallExpression to user declared
- // functions
- bool is_function = function_calls_.count(call) != 0;
+ bool is_function = call->Target()->Is<sem::Function>();
AddError((is_function ? "function" : "intrinsic") + std::string(" '") +
name + "' does not return a value",
- call->source);
+ call->Declaration()->source);
return false;
}
}
@@ -2465,47 +2591,8 @@
return true;
}
-bool Resolver::ValidateCallStatement(const ast::CallStatement*) {
- return true;
-}
-
-bool Resolver::IntrinsicCall(const ast::CallExpression* call,
- sem::IntrinsicType intrinsic_type) {
- std::vector<const sem::Type*> arg_tys;
- arg_tys.reserve(call->args.size());
- for (auto* expr : call->args) {
- arg_tys.emplace_back(TypeOf(expr));
- }
-
- auto* result =
- intrinsic_table_->Lookup(intrinsic_type, arg_tys, call->source);
- if (!result) {
- return false;
- }
-
- if (result->IsDeprecated()) {
- AddWarning("use of deprecated intrinsic", call->source);
- }
-
- auto* out = builder_->create<sem::Call>(call, result, current_statement_);
- builder_->Sem().Add(call, out);
- SetExprInfo(call, result->ReturnType());
-
- current_function_->intrinsic_calls.emplace_back(
- IntrinsicCallInfo{call, result});
-
- if (IsTextureIntrinsic(intrinsic_type) &&
- !ValidateTextureIntrinsicFunction(call, out)) {
- return false;
- }
-
- return true;
-}
-
-bool Resolver::ValidateTextureIntrinsicFunction(
- const ast::CallExpression* ast_call,
- const sem::Call* sem_call) {
- auto* intrinsic = sem_call->Target()->As<sem::Intrinsic>();
+bool Resolver::ValidateTextureIntrinsicFunction(const sem::Call* call) {
+ auto* intrinsic = call->Target()->As<sem::Intrinsic>();
if (!intrinsic) {
return false;
}
@@ -2513,146 +2600,111 @@
auto& signature = intrinsic->Signature();
auto index = signature.IndexOf(sem::ParameterUsage::kOffset);
if (index > -1) {
- auto* param = ast_call->args[index];
- if (param->Is<ast::TypeConstructorExpression>()) {
- auto values = ConstantValueOf(param);
- if (!values.IsValid()) {
- AddError(
- "'" + func_name + "' offset parameter must be a const_expression",
- param->source);
- return false;
- }
+ auto* arg = call->Arguments()[index];
+ if (auto values = arg->ConstantValue()) {
+ // Assert that the constant values are of the expected type.
if (!values.Type()->Is<sem::Vector>() ||
!values.ElementType()->Is<sem::I32>()) {
TINT_ICE(Resolver, diagnostics_)
<< "failed to resolve '" + func_name + "' offset parameter type";
return false;
}
- for (auto offset : values.Elements()) {
- auto offset_value = offset.i32;
- if (offset_value < -8 || offset_value > 7) {
- AddError("each offset component of '" + func_name +
- "' must be at least -8 and at most 7. "
- "found: '" +
- std::to_string(offset_value) + "'",
- param->source);
- return false;
+
+ // Currently const_expr is restricted to literals and type constructors.
+ // Check that that's all we have for the offset parameter.
+ bool is_const_expr = true;
+ ast::TraverseExpressions(
+ arg->Declaration(), diagnostics_, [&](const ast::Expression* e) {
+ if (e->IsAnyOf<ast::ScalarConstructorExpression,
+ ast::TypeConstructorExpression>()) {
+ return ast::TraverseAction::Descend;
+ }
+ is_const_expr = false;
+ return ast::TraverseAction::Stop;
+ });
+ if (is_const_expr) {
+ for (auto offset : values.Elements()) {
+ auto offset_value = offset.i32;
+ if (offset_value < -8 || offset_value > 7) {
+ AddError("each offset component of '" + func_name +
+ "' must be at least -8 and at most 7. "
+ "found: '" +
+ std::to_string(offset_value) + "'",
+ arg->Declaration()->source);
+ return false;
+ }
}
+ return true;
}
- } else {
- AddError(
- "'" + func_name + "' offset parameter must be a const_expression",
- param->source);
- return false;
}
- }
- return true;
-}
-
-bool Resolver::FunctionCall(const ast::CallExpression* call) {
- auto* ident = call->func;
- auto name = builder_->Symbols().NameFor(ident->symbol);
-
- auto callee_func_it = symbol_to_function_.find(ident->symbol);
- if (callee_func_it == symbol_to_function_.end()) {
- if (current_function_ &&
- current_function_->declaration->symbol == ident->symbol) {
- AddError("recursion is not permitted. '" + name +
- "' attempted to call itself.",
- call->source);
- } else {
- AddError("unable to find called function: " + name, call->source);
- }
- return false;
- }
- auto* callee_func = callee_func_it->second;
-
- if (current_function_) {
- callee_func->callsites.push_back(call);
-
- // Note: Requires called functions to be resolved first.
- // This is currently guaranteed as functions must be declared before
- // use.
- current_function_->transitive_calls.add(callee_func);
- for (auto* transitive_call : callee_func->transitive_calls) {
- current_function_->transitive_calls.add(transitive_call);
- }
-
- // We inherit any referenced variables from the callee.
- for (auto* var : callee_func->referenced_module_vars) {
- set_referenced_from_function_if_needed(var, false);
- }
- }
-
- function_calls_.emplace(call,
- FunctionCallInfo{callee_func, current_statement_});
- SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
-
- if (!ValidateFunctionCall(call, callee_func)) {
+ AddError("'" + func_name + "' offset parameter must be a const_expression",
+ arg->Declaration()->source);
return false;
}
return true;
}
-bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
- const FunctionInfo* target) {
- auto* ident = call->func;
+bool Resolver::ValidateFunctionCall(const sem::Call* call) {
+ auto* decl = call->Declaration();
+ auto* ident = decl->func;
+ auto* target = call->Target()->As<sem::Function>();
auto name = builder_->Symbols().NameFor(ident->symbol);
- if (target->declaration->IsEntryPoint()) {
+ if (target->Declaration()->IsEntryPoint()) {
// https://www.w3.org/TR/WGSL/#function-restriction
// An entry point must never be the target of a function call.
AddError("entry point functions cannot be the target of a function call",
- call->source);
+ decl->source);
return false;
}
- if (call->args.size() != target->parameters.size()) {
- bool more = call->args.size() > target->parameters.size();
+ if (decl->args.size() != target->Parameters().size()) {
+ bool more = decl->args.size() > target->Parameters().size();
AddError("too " + (more ? std::string("many") : std::string("few")) +
" arguments in call to '" + name + "', expected " +
- std::to_string(target->parameters.size()) + ", got " +
- std::to_string(call->args.size()),
- call->source);
+ std::to_string(target->Parameters().size()) + ", got " +
+ std::to_string(call->Arguments().size()),
+ decl->source);
return false;
}
- for (size_t i = 0; i < call->args.size(); ++i) {
- const VariableInfo* param = target->parameters[i];
- const ast::Expression* arg_expr = call->args[i];
+ for (size_t i = 0; i < call->Arguments().size(); ++i) {
+ const sem::Variable* param = target->Parameters()[i];
+ const ast::Expression* arg_expr = decl->args[i];
+ auto* param_type = param->Type();
auto* arg_type = TypeOf(arg_expr)->UnwrapRef();
- if (param->type != arg_type) {
+ if (param_type != arg_type) {
AddError("type mismatch for argument " + std::to_string(i + 1) +
" in call to '" + name + "', expected '" +
- param->type->FriendlyName(builder_->Symbols()) + "', got '" +
- arg_type->FriendlyName(builder_->Symbols()) + "'",
+ TypeNameOf(param_type) + "', got '" + TypeNameOf(arg_type) +
+ "'",
arg_expr->source);
return false;
}
- if (param->declaration->type->Is<ast::Pointer>()) {
+ if (param_type->Is<sem::Pointer>()) {
auto is_valid = false;
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
- VariableInfo* var = variable_stack_.Get(ident_expr->symbol);
+ auto* var = variable_stack_.Get(ident_expr->symbol);
if (!var) {
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false;
}
- if (var->kind == VariableKind::kParameter) {
+ if (var->Is<sem::Parameter>()) {
is_valid = true;
}
} else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
if (unary->op == ast::UnaryOp::kAddressOf) {
if (auto* ident_unary =
unary->expr->As<ast::IdentifierExpression>()) {
- VariableInfo* var = variable_stack_.Get(ident_unary->symbol);
+ auto* var = variable_stack_.Get(ident_unary->symbol);
if (!var) {
TINT_ICE(Resolver, diagnostics_)
<< "failed to resolve identifier";
return false;
}
- if (var->declaration->is_const) {
+ if (var->Declaration()->is_const) {
TINT_ICE(Resolver, diagnostics_)
<< "Resolver::FunctionCall() encountered an address-of "
"expression of a constant identifier expression";
@@ -2665,7 +2717,7 @@
if (!is_valid &&
IsValidationEnabled(
- param->declaration->decorations,
+ param->Declaration()->decorations,
ast::DisabledValidation::kIgnoreInvalidPointerArgument)) {
AddError(
"expected an address-of expression of a variable identifier "
@@ -2678,52 +2730,52 @@
return true;
}
-bool Resolver::Constructor(const ast::ConstructorExpression* expr) {
+sem::Expression* Resolver::Constructor(const ast::ConstructorExpression* expr) {
if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
- auto* type = Type(type_ctor->type);
- if (!type) {
- return false;
+ auto* ty = Type(type_ctor->type);
+ if (!ty) {
+ return nullptr;
}
- auto type_name = type_ctor->type->FriendlyName(builder_->Symbols());
-
// Now that the argument types have been determined, make sure that they
// obey the constructor type rules laid out in
// https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
bool ok = true;
- if (auto* vec_type = type->As<sem::Vector>()) {
- ok = ValidateVectorConstructor(type_ctor, vec_type, type_name);
- } else if (auto* mat_type = type->As<sem::Matrix>()) {
- ok = ValidateMatrixConstructor(type_ctor, mat_type, type_name);
- } else if (type->is_scalar()) {
- ok = ValidateScalarConstructor(type_ctor, type, type_name);
- } else if (auto* arr_type = type->As<sem::Array>()) {
+ if (auto* vec_type = ty->As<sem::Vector>()) {
+ ok = ValidateVectorConstructor(type_ctor, vec_type);
+ } else if (auto* mat_type = ty->As<sem::Matrix>()) {
+ ok = ValidateMatrixConstructor(type_ctor, mat_type);
+ } else if (ty->is_scalar()) {
+ ok = ValidateScalarConstructor(type_ctor, ty);
+ } else if (auto* arr_type = ty->As<sem::Array>()) {
ok = ValidateArrayConstructor(type_ctor, arr_type);
- } else if (auto* struct_type = type->As<sem::Struct>()) {
+ } else if (auto* struct_type = ty->As<sem::Struct>()) {
ok = ValidateStructureConstructor(type_ctor, struct_type);
} else {
AddError("type is not constructible", type_ctor->source);
- return false;
+ return nullptr;
}
if (!ok) {
- return false;
+ return nullptr;
}
- SetExprInfo(expr, type, type_name);
- return true;
+
+ auto val = EvaluateConstantValue(expr, ty);
+ return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
}
if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
Mark(scalar_ctor->literal);
- auto* type = TypeOf(scalar_ctor->literal);
- if (!type) {
- return false;
+ auto* ty = TypeOf(scalar_ctor->literal);
+ if (!ty) {
+ return nullptr;
}
- SetExprInfo(expr, type);
- return true;
+
+ auto val = EvaluateConstantValue(expr, ty);
+ return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
}
TINT_ICE(Resolver, diagnostics_) << "unexpected constructor expression type";
- return false;
+ return nullptr;
}
bool Resolver::ValidateStructureConstructor(
@@ -2746,12 +2798,13 @@
}
for (auto* member : struct_type->Members()) {
auto* value = ctor->values[member->Index()];
- if (member->Type() != TypeOf(value)->UnwrapRef()) {
+ auto* value_ty = TypeOf(value);
+ if (member->Type() != value_ty->UnwrapRef()) {
AddError(
"type in struct constructor does not match struct member type: "
"expected '" +
- member->Type()->FriendlyName(builder_->Symbols()) +
- "', found '" + TypeNameOf(value) + "'",
+ TypeNameOf(member->Type()) + "', found '" +
+ TypeNameOf(value_ty) + "'",
value->source);
return false;
}
@@ -2764,15 +2817,14 @@
const ast::TypeConstructorExpression* ctor,
const sem::Array* array_type) {
auto& values = ctor->values;
- auto* elem_type = array_type->ElemType();
+ auto* elem_ty = array_type->ElemType();
for (auto* value : values) {
- auto* value_type = TypeOf(value)->UnwrapRef();
- if (value_type != elem_type) {
+ auto* value_ty = TypeOf(value)->UnwrapRef();
+ if (value_ty != elem_ty) {
AddError(
"type in array constructor does not match array type: "
"expected '" +
- elem_type->FriendlyName(builder_->Symbols()) + "', found '" +
- TypeNameOf(value) + "'",
+ TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'",
value->source);
return false;
}
@@ -2781,7 +2833,7 @@
if (array_type->IsRuntimeSized()) {
AddError("cannot init a runtime-sized array", ctor->source);
return false;
- } else if (!elem_type->IsConstructible()) {
+ } else if (!elem_ty->IsConstructible()) {
AddError("array constructor has non-constructible element type",
ctor->type->As<ast::Array>()->type->source);
return false;
@@ -2804,36 +2856,34 @@
bool Resolver::ValidateVectorConstructor(
const ast::TypeConstructorExpression* ctor,
- const sem::Vector* vec_type,
- const std::string& type_name) {
+ const sem::Vector* vec_type) {
auto& values = ctor->values;
- auto* elem_type = vec_type->type();
+ auto* elem_ty = vec_type->type();
size_t value_cardinality_sum = 0;
for (auto* value : values) {
- auto* value_type = TypeOf(value)->UnwrapRef();
- if (value_type->is_scalar()) {
- if (elem_type != value_type) {
+ auto* value_ty = TypeOf(value)->UnwrapRef();
+ if (value_ty->is_scalar()) {
+ if (elem_ty != value_ty) {
AddError(
"type in vector constructor does not match vector type: "
"expected '" +
- elem_type->FriendlyName(builder_->Symbols()) + "', found '" +
- TypeNameOf(value) + "'",
+ TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'",
value->source);
return false;
}
value_cardinality_sum++;
- } else if (auto* value_vec = value_type->As<sem::Vector>()) {
- auto* value_elem_type = value_vec->type();
+ } else if (auto* value_vec = value_ty->As<sem::Vector>()) {
+ auto* value_elem_ty = value_vec->type();
// A mismatch of vector type parameter T is only an error if multiple
// arguments are present. A single argument constructor constitutes a
// type conversion expression.
- if (elem_type != value_elem_type && values.size() > 1u) {
+ if (elem_ty != value_elem_ty && values.size() > 1u) {
AddError(
"type in vector constructor does not match vector type: "
"expected '" +
- elem_type->FriendlyName(builder_->Symbols()) + "', found '" +
- value_elem_type->FriendlyName(builder_->Symbols()) + "'",
+ TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_elem_ty) +
+ "'",
value->source);
return false;
}
@@ -2842,7 +2892,7 @@
} else {
// A vector constructor can only accept vectors and scalars.
AddError("expected vector or scalar type in vector constructor; found: " +
- value_type->FriendlyName(builder_->Symbols()),
+ TypeNameOf(value_ty),
value->source);
return false;
}
@@ -2858,7 +2908,7 @@
}
const Source& values_start = values[0]->source;
const Source& values_end = values[values.size() - 1]->source;
- AddError("attempted to construct '" + type_name + "' with " +
+ AddError("attempted to construct '" + TypeNameOf(vec_type) + "' with " +
std::to_string(value_cardinality_sum) + " component(s)",
Source::Combine(values_start, values_end));
return false;
@@ -2885,27 +2935,27 @@
bool Resolver::ValidateMatrixConstructor(
const ast::TypeConstructorExpression* ctor,
- const sem::Matrix* matrix_type,
- const std::string& type_name) {
+ const sem::Matrix* matrix_ty) {
auto& values = ctor->values;
// Zero Value expression
if (values.empty()) {
return true;
}
- if (!ValidateMatrix(matrix_type, ctor->source)) {
+ if (!ValidateMatrix(matrix_ty, ctor->source)) {
return false;
}
- auto* elem_type = matrix_type->type();
- auto num_elements = matrix_type->columns() * matrix_type->rows();
+ auto* elem_type = matrix_ty->type();
+ auto num_elements = matrix_ty->columns() * matrix_ty->rows();
// Print a generic error for an invalid matrix constructor, showing the
// available overloads.
auto print_error = [&]() {
const Source& values_start = values[0]->source;
const Source& values_end = values[values.size() - 1]->source;
- auto elem_type_name = elem_type->FriendlyName(builder_->Symbols());
+ auto type_name = TypeNameOf(matrix_ty);
+ auto elem_type_name = TypeNameOf(elem_type);
std::stringstream ss;
ss << "invalid constructor for " + type_name << std::endl << std::endl;
ss << "3 candidates available:" << std::endl;
@@ -2914,11 +2964,11 @@
<< elem_type_name << ")"
<< " // " << std::to_string(num_elements) << " arguments" << std::endl;
ss << " " << type_name << "(";
- for (uint32_t c = 0; c < matrix_type->columns(); c++) {
+ for (uint32_t c = 0; c < matrix_ty->columns(); c++) {
if (c > 0) {
ss << ", ";
}
- ss << VectorPretty(matrix_type->rows(), elem_type);
+ ss << VectorPretty(matrix_ty->rows(), elem_type);
}
ss << ")" << std::endl;
AddError(ss.str(), Source::Combine(values_start, values_end));
@@ -2927,10 +2977,10 @@
const sem::Type* expected_arg_type = nullptr;
if (num_elements == values.size()) {
// Column-major construction from scalar elements.
- expected_arg_type = matrix_type->type();
- } else if (matrix_type->columns() == values.size()) {
+ expected_arg_type = matrix_ty->type();
+ } else if (matrix_ty->columns() == values.size()) {
// Column-by-column construction from vectors.
- expected_arg_type = matrix_type->ColumnType();
+ expected_arg_type = matrix_ty->ColumnType();
} else {
print_error();
return false;
@@ -2948,8 +2998,7 @@
bool Resolver::ValidateScalarConstructor(
const ast::TypeConstructorExpression* ctor,
- const sem::Type* type,
- const std::string& type_name) {
+ const sem::Type* ty) {
if (ctor->values.size() == 0) {
return true;
}
@@ -2962,20 +3011,20 @@
// Validate constructor
auto* value = ctor->values[0];
- auto* value_type = TypeOf(value)->UnwrapRef();
+ auto* value_ty = TypeOf(value)->UnwrapRef();
using Bool = sem::Bool;
using I32 = sem::I32;
using U32 = sem::U32;
using F32 = sem::F32;
- const bool is_valid = (type->Is<Bool>() && value_type->is_scalar()) ||
- (type->Is<I32>() && value_type->is_scalar()) ||
- (type->Is<U32>() && value_type->is_scalar()) ||
- (type->Is<F32>() && value_type->is_scalar());
+ const bool is_valid = (ty->Is<Bool>() && value_ty->is_scalar()) ||
+ (ty->Is<I32>() && value_ty->is_scalar()) ||
+ (ty->Is<U32>() && value_ty->is_scalar()) ||
+ (ty->Is<F32>() && value_ty->is_scalar());
if (!is_valid) {
- AddError("cannot construct '" + type_name + "' with a value of type '" +
- TypeNameOf(value) + "'",
+ AddError("cannot construct '" + TypeNameOf(ty) +
+ "' with a value of type '" + TypeNameOf(value_ty) + "'",
ctor->source);
return false;
@@ -2984,13 +3033,11 @@
return true;
}
-bool Resolver::Identifier(const ast::IdentifierExpression* expr) {
+sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
auto symbol = expr->symbol;
- if (VariableInfo* var = variable_stack_.Get(symbol)) {
- SetExprInfo(expr, var->type, var->type_name);
-
- var->users.push_back(expr);
- set_referenced_from_function_if_needed(var, true);
+ if (auto* var = variable_stack_.Get(symbol)) {
+ auto* user =
+ builder_->create<sem::VariableUser>(expr, current_statement_, var);
if (current_statement_) {
// If identifier is part of a loop continuing block, make sure it
@@ -3021,40 +3068,47 @@
AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
"' referenced in continuing block here",
expr->source);
- return false;
+ return nullptr;
}
}
}
}
}
- return true;
+ if (current_function_) {
+ if (auto* global = var->As<sem::GlobalVariable>()) {
+ current_function_->AddDirectlyReferencedGlobal(global);
+ }
+ }
+
+ var->AddUser(user);
+ return user;
}
- auto iter = symbol_to_function_.find(symbol);
- if (iter != symbol_to_function_.end()) {
+ if (symbol_to_function_.count(symbol)) {
AddError("missing '(' for function call", expr->source.End());
- return false;
+ return nullptr;
}
std::string name = builder_->Symbols().NameFor(symbol);
if (sem::ParseIntrinsicType(name) != IntrinsicType::kNone) {
AddError("missing '(' for intrinsic call", expr->source.End());
- return false;
+ return nullptr;
}
AddError("identifier must be declared before use: " + name, expr->source);
- return false;
+ return nullptr;
}
-bool Resolver::MemberAccessor(const ast::MemberAccessorExpression* expr) {
+sem::Expression* Resolver::MemberAccessor(
+ const ast::MemberAccessorExpression* expr) {
auto* structure = TypeOf(expr->structure);
- auto* storage_type = structure->UnwrapRef();
+ auto* storage_ty = structure->UnwrapRef();
const sem::Type* ret = nullptr;
std::vector<uint32_t> swizzle;
- if (auto* str = storage_type->As<sem::Struct>()) {
+ if (auto* str = storage_ty->As<sem::Struct>()) {
Mark(expr->member);
auto symbol = expr->member->symbol;
@@ -3071,7 +3125,7 @@
AddError(
"struct member " + builder_->Symbols().NameFor(symbol) + " not found",
expr->source);
- return false;
+ return nullptr;
}
// If we're extracting from a reference, we return a reference.
@@ -3080,9 +3134,11 @@
ref->Access());
}
- builder_->Sem().Add(expr, builder_->create<sem::StructMemberAccess>(
- expr, ret, current_statement_, member));
- } else if (auto* vec = storage_type->As<sem::Vector>()) {
+ return builder_->create<sem::StructMemberAccess>(
+ expr, ret, current_statement_, member);
+ }
+
+ if (auto* vec = storage_ty->As<sem::Vector>()) {
Mark(expr->member);
std::string s = builder_->Symbols().NameFor(expr->member->symbol);
auto size = s.size();
@@ -3109,18 +3165,18 @@
default:
AddError("invalid vector swizzle character",
expr->member->source.Begin() + swizzle.size());
- return false;
+ return nullptr;
}
if (swizzle.back() >= vec->Width()) {
AddError("invalid vector swizzle member", expr->member->source);
- return false;
+ return nullptr;
}
}
if (size < 1 || size > 4) {
AddError("invalid vector swizzle size", expr->member->source);
- return false;
+ return nullptr;
}
// All characters are valid, check if they're being mixed
@@ -3134,7 +3190,7 @@
!std::all_of(s.begin(), s.end(), is_xyzw)) {
AddError("invalid mixing of vector swizzle characters rgba with xyzw",
expr->member->source);
- return false;
+ return nullptr;
}
if (size == 1) {
@@ -3151,23 +3207,18 @@
ret = builder_->create<sem::Vector>(vec->type(),
static_cast<uint32_t>(size));
}
- builder_->Sem().Add(
- expr, builder_->create<sem::Swizzle>(expr, ret, current_statement_,
- std::move(swizzle)));
- } else {
- AddError(
- "invalid member accessor expression. Expected vector or struct, got '" +
- TypeNameOf(expr->structure) + "'",
- expr->structure->source);
- return false;
+ return builder_->create<sem::Swizzle>(expr, ret, current_statement_,
+ std::move(swizzle));
}
- SetExprInfo(expr, ret);
-
- return true;
+ AddError(
+ "invalid member accessor expression. Expected vector or struct, got '" +
+ TypeNameOf(storage_ty) + "'",
+ expr->structure->source);
+ return nullptr;
}
-bool Resolver::Binary(const ast::BinaryExpression* expr) {
+sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
using Bool = sem::Bool;
using F32 = sem::F32;
using I32 = sem::I32;
@@ -3175,12 +3226,12 @@
using Matrix = sem::Matrix;
using Vector = sem::Vector;
- auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
- auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();
+ auto* lhs_ty = TypeOf(expr->lhs)->UnwrapRef();
+ auto* rhs_ty = TypeOf(expr->rhs)->UnwrapRef();
- auto* lhs_vec = lhs_type->As<Vector>();
+ auto* lhs_vec = lhs_ty->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
- auto* rhs_vec = rhs_type->As<Vector>();
+ auto* rhs_vec = rhs_ty->As<Vector>();
auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
const bool matching_vec_elem_types =
@@ -3188,70 +3239,66 @@
(lhs_vec_elem_type == rhs_vec_elem_type) &&
(lhs_vec->Width() == rhs_vec->Width());
- const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type);
+ const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
+
+ auto build = [&](const sem::Type* ty) {
+ auto val = EvaluateConstantValue(expr, ty);
+ return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+ };
// Binary logical expressions
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
- if (matching_types && lhs_type->Is<Bool>()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ if (matching_types && lhs_ty->Is<Bool>()) {
+ return build(lhs_ty);
}
}
if (expr->IsOr() || expr->IsAnd()) {
- if (matching_types && lhs_type->Is<Bool>()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ if (matching_types && lhs_ty->Is<Bool>()) {
+ return build(lhs_ty);
}
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ return build(lhs_ty);
}
}
// Arithmetic expressions
if (expr->IsArithmetic()) {
// Binary arithmetic expressions over scalars
- if (matching_types && lhs_type->is_numeric_scalar()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ if (matching_types && lhs_ty->is_numeric_scalar()) {
+ return build(lhs_ty);
}
// Binary arithmetic expressions over vectors
if (matching_types && lhs_vec_elem_type &&
lhs_vec_elem_type->is_numeric_scalar()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ return build(lhs_ty);
}
// Binary arithmetic expressions with mixed scalar and vector operands
- if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_type)) {
+ if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty)) {
if (expr->IsModulo()) {
- if (rhs_type->is_integer_scalar()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ if (rhs_ty->is_integer_scalar()) {
+ return build(lhs_ty);
}
- } else if (rhs_type->is_numeric_scalar()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ } else if (rhs_ty->is_numeric_scalar()) {
+ return build(lhs_ty);
}
}
- if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_type)) {
+ if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty)) {
if (expr->IsModulo()) {
- if (lhs_type->is_integer_scalar()) {
- SetExprInfo(expr, rhs_type);
- return true;
+ if (lhs_ty->is_integer_scalar()) {
+ return build(rhs_ty);
}
- } else if (lhs_type->is_numeric_scalar()) {
- SetExprInfo(expr, rhs_type);
- return true;
+ } else if (lhs_ty->is_numeric_scalar()) {
+ return build(rhs_ty);
}
}
}
// Matrix arithmetic
- auto* lhs_mat = lhs_type->As<Matrix>();
+ auto* lhs_mat = lhs_ty->As<Matrix>();
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
- auto* rhs_mat = rhs_type->As<Matrix>();
+ auto* rhs_mat = rhs_ty->As<Matrix>();
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Addition and subtraction of float matrices
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
@@ -3259,49 +3306,42 @@
rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->columns()) &&
(lhs_mat->rows() == rhs_mat->rows())) {
- SetExprInfo(expr, rhs_type);
- return true;
+ return build(rhs_ty);
}
if (expr->IsMultiply()) {
// Multiplication of a matrix and a scalar
- if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
+ if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>()) {
- SetExprInfo(expr, rhs_type);
- return true;
+ return build(rhs_ty);
}
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
- rhs_type->Is<F32>()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ rhs_ty->Is<F32>()) {
+ return build(lhs_ty);
}
// Vector times matrix
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_vec->Width() == rhs_mat->rows())) {
- SetExprInfo(expr, builder_->create<sem::Vector>(lhs_vec->type(),
- rhs_mat->columns()));
- return true;
+ return build(
+ builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()));
}
// Matrix times vector
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_vec->Width())) {
- SetExprInfo(expr, builder_->create<sem::Vector>(rhs_vec->type(),
- lhs_mat->rows()));
- return true;
+ return build(
+ builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
}
// Matrix times matrix
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->rows())) {
- SetExprInfo(expr, builder_->create<sem::Matrix>(
- builder_->create<sem::Vector>(lhs_mat_elem_type,
- lhs_mat->rows()),
- rhs_mat->columns()));
- return true;
+ return build(builder_->create<sem::Matrix>(
+ builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
+ rhs_mat->columns()));
}
}
@@ -3309,15 +3349,13 @@
if (expr->IsComparison()) {
if (matching_types) {
// Special case for bools: only == and !=
- if (lhs_type->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
- SetExprInfo(expr, builder_->create<sem::Bool>());
- return true;
+ if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
+ return build(builder_->create<sem::Bool>());
}
// For the rest, we can compare i32, u32, and f32
- if (lhs_type->IsAnyOf<I32, U32, F32>()) {
- SetExprInfo(expr, builder_->create<sem::Bool>());
- return true;
+ if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
+ return build(builder_->create<sem::Bool>());
}
}
@@ -3325,24 +3363,21 @@
if (matching_vec_elem_types) {
if (lhs_vec_elem_type->Is<Bool>() &&
(expr->IsEqual() || expr->IsNotEqual())) {
- SetExprInfo(expr, builder_->create<sem::Vector>(
- builder_->create<sem::Bool>(), lhs_vec->Width()));
- return true;
+ return build(builder_->create<sem::Vector>(
+ builder_->create<sem::Bool>(), lhs_vec->Width()));
}
if (lhs_vec_elem_type->is_numeric_scalar()) {
- SetExprInfo(expr, builder_->create<sem::Vector>(
- builder_->create<sem::Bool>(), lhs_vec->Width()));
- return true;
+ return build(builder_->create<sem::Vector>(
+ builder_->create<sem::Bool>(), lhs_vec->Width()));
}
}
}
// Binary bitwise operations
if (expr->IsBitwise()) {
- if (matching_types && lhs_type->is_integer_scalar_or_vector()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
+ return build(lhs_ty);
}
}
@@ -3352,79 +3387,72 @@
// differences in computation rules (i.e. right shift can be arithmetic or
// logical depending on lhs type).
- if (lhs_type->IsAnyOf<I32, U32>() && rhs_type->Is<U32>()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
+ return build(lhs_ty);
}
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
- SetExprInfo(expr, lhs_type);
- return true;
+ return build(lhs_ty);
}
}
AddError("Binary expression operand types are invalid for this operation: " +
- lhs_type->FriendlyName(builder_->Symbols()) + " " +
- FriendlyName(expr->op) + " " +
- rhs_type->FriendlyName(builder_->Symbols()),
+ TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
+ TypeNameOf(rhs_ty),
expr->source);
- return false;
+ return nullptr;
}
-bool Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
- auto* expr_type = TypeOf(unary->expr);
- if (!expr_type) {
- return false;
+sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
+ auto* expr_ty = TypeOf(unary->expr);
+ if (!expr_ty) {
+ return nullptr;
}
- std::string type_name;
- const sem::Type* type = nullptr;
+ const sem::Type* ty = nullptr;
switch (unary->op) {
case ast::UnaryOp::kNot:
// Result type matches the deref'd inner type.
- type_name = TypeNameOf(unary->expr);
- type = expr_type->UnwrapRef();
- if (!type->Is<sem::Bool>() && !type->is_bool_vector()) {
- AddError("cannot logical negate expression of type '" +
- TypeNameOf(unary->expr),
- unary->expr->source);
- return false;
+ ty = expr_ty->UnwrapRef();
+ if (!ty->Is<sem::Bool>() && !ty->is_bool_vector()) {
+ AddError(
+ "cannot logical negate expression of type '" + TypeNameOf(expr_ty),
+ unary->expr->source);
+ return nullptr;
}
break;
case ast::UnaryOp::kComplement:
// Result type matches the deref'd inner type.
- type_name = TypeNameOf(unary->expr);
- type = expr_type->UnwrapRef();
- if (!type->is_integer_scalar_or_vector()) {
+ ty = expr_ty->UnwrapRef();
+ if (!ty->is_integer_scalar_or_vector()) {
AddError("cannot bitwise complement expression of type '" +
- TypeNameOf(unary->expr),
+ TypeNameOf(expr_ty),
unary->expr->source);
- return false;
+ return nullptr;
}
break;
case ast::UnaryOp::kNegation:
// Result type matches the deref'd inner type.
- type_name = TypeNameOf(unary->expr);
- type = expr_type->UnwrapRef();
- if (!(type->IsAnyOf<sem::F32, sem::I32>() ||
- type->is_signed_integer_vector() || type->is_float_vector())) {
- AddError("cannot negate expression of type '" + TypeNameOf(unary->expr),
+ ty = expr_ty->UnwrapRef();
+ if (!(ty->IsAnyOf<sem::F32, sem::I32>() ||
+ ty->is_signed_integer_vector() || ty->is_float_vector())) {
+ AddError("cannot negate expression of type '" + TypeNameOf(expr_ty),
unary->expr->source);
- return false;
+ return nullptr;
}
break;
case ast::UnaryOp::kAddressOf:
- if (auto* ref = expr_type->As<sem::Reference>()) {
+ if (auto* ref = expr_ty->As<sem::Reference>()) {
if (ref->StoreType()->UnwrapRef()->is_handle()) {
AddError(
"cannot take the address of expression in handle storage class",
unary->expr->source);
- return false;
+ return nullptr;
}
auto* array = unary->expr->As<ast::ArrayAccessorExpression>();
@@ -3434,48 +3462,48 @@
TypeOf(member->structure)->UnwrapRef()->Is<sem::Vector>())) {
AddError("cannot take the address of a vector component",
unary->expr->source);
- return false;
+ return nullptr;
}
- type = builder_->create<sem::Pointer>(
- ref->StoreType(), ref->StorageClass(), ref->Access());
+ ty = builder_->create<sem::Pointer>(ref->StoreType(),
+ ref->StorageClass(), ref->Access());
} else {
AddError("cannot take the address of expression", unary->expr->source);
- return false;
+ return nullptr;
}
break;
case ast::UnaryOp::kIndirection:
- if (auto* ptr = expr_type->As<sem::Pointer>()) {
- type = builder_->create<sem::Reference>(
+ if (auto* ptr = expr_ty->As<sem::Pointer>()) {
+ ty = builder_->create<sem::Reference>(
ptr->StoreType(), ptr->StorageClass(), ptr->Access());
} else {
AddError("cannot dereference expression of type '" +
- TypeNameOf(unary->expr) + "'",
+ TypeNameOf(expr_ty) + "'",
unary->expr->source);
- return false;
+ return nullptr;
}
break;
}
- SetExprInfo(unary, type);
- return true;
+ auto val = EvaluateConstantValue(unary, ty);
+ return builder_->create<sem::Expression>(unary, ty, current_statement_, val);
}
bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
- const ast::Variable* var = stmt->variable;
- Mark(var);
+ Mark(stmt->variable);
- if (!ValidateNoDuplicateDefinition(var->symbol, var->source)) {
+ if (!ValidateNoDuplicateDefinition(stmt->variable->symbol,
+ stmt->variable->source)) {
return false;
}
- auto* info = Variable(var, VariableKind::kLocal);
- if (!info) {
+ auto* var = Variable(stmt->variable, VariableKind::kLocal);
+ if (!var) {
return false;
}
- for (auto* deco : var->decorations) {
+ for (auto* deco : stmt->variable->decorations) {
Mark(deco);
if (!deco->Is<ast::InternalDecoration>()) {
AddError("decorations are not valid on local variables", deco->source);
@@ -3483,38 +3511,12 @@
}
}
- variable_stack_.Set(var->symbol, info);
+ variable_stack_.Set(stmt->variable->symbol, var);
if (current_block_) { // Not all statements are inside a block
- current_block_->AddDecl(var);
+ current_block_->AddDecl(stmt->variable);
}
- if (!ValidateVariable(info)) {
- return false;
- }
-
- if (!var->is_const &&
- IsValidationEnabled(var->decorations,
- ast::DisabledValidation::kIgnoreStorageClass)) {
- if (!info->type->UnwrapRef()->IsConstructible()) {
- AddError("function variable must have a constructible type",
- var->type ? var->type->source : var->source);
- return false;
- }
- if (info->storage_class != ast::StorageClass::kFunction) {
- if (info->storage_class != ast::StorageClass::kNone) {
- AddError("function variable has a non-function storage class",
- stmt->source);
- return false;
- }
- info->storage_class = ast::StorageClass::kFunction;
- }
- }
-
- if (!ApplyStorageClassUsageToType(info->storage_class, info->type,
- var->source)) {
- AddNote("while instantiating variable " +
- builder_->Symbols().NameFor(var->symbol),
- var->source);
+ if (!ValidateVariable(var)) {
return false;
}
@@ -3563,19 +3565,16 @@
}
sem::Type* Resolver::TypeOf(const ast::Expression* expr) {
- auto it = expr_info_.find(expr);
- if (it != expr_info_.end()) {
- return const_cast<sem::Type*>(it->second.type);
- }
- return nullptr;
+ auto* sem = Sem(expr);
+ return sem ? const_cast<sem::Type*>(sem->Type()) : nullptr;
}
-std::string Resolver::TypeNameOf(const ast::Expression* expr) {
- auto it = expr_info_.find(expr);
- if (it != expr_info_.end()) {
- return it->second.type_name;
- }
- return "";
+std::string Resolver::TypeNameOf(const sem::Type* ty) {
+ return RawTypeNameOf(ty->UnwrapRef());
+}
+
+std::string Resolver::RawTypeNameOf(const sem::Type* ty) {
+ return ty->FriendlyName(builder_->Symbols());
}
sem::Type* Resolver::TypeOf(const ast::Literal* lit) {
@@ -3596,56 +3595,37 @@
return nullptr;
}
-void Resolver::SetExprInfo(const ast::Expression* expr,
- const sem::Type* type,
- std::string type_name) {
- if (expr_info_.count(expr)) {
- TINT_ICE(Resolver, diagnostics_)
- << "SetExprInfo() called twice for the same expression";
- }
- if (type_name.empty()) {
- type_name = type->FriendlyName(builder_->Symbols());
- }
- auto constant_value = EvaluateConstantValue(expr, type);
- expr_info_.emplace(
- expr, ExpressionInfo{type, std::move(type_name), current_statement_,
- std::move(constant_value)});
-}
-
bool Resolver::ValidatePipelineStages() {
- auto check_workgroup_storage = [&](FunctionInfo* func,
- FunctionInfo* entry_point) {
- auto stage = entry_point->declaration->PipelineStage();
+ auto check_workgroup_storage = [&](const sem::Function* func,
+ const sem::Function* entry_point) {
+ auto stage = entry_point->Declaration()->PipelineStage();
if (stage != ast::PipelineStage::kCompute) {
- for (auto* var : func->local_referenced_module_vars) {
- if (var->storage_class == ast::StorageClass::kWorkgroup) {
+ for (auto* var : func->DirectlyReferencedGlobals()) {
+ if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
std::stringstream stage_name;
stage_name << stage;
- for (auto* user : var->users) {
- auto it = expr_info_.find(user->As<ast::Expression>());
- if (it != expr_info_.end()) {
- if (func->declaration->symbol ==
- it->second.statement->Function()->symbol) {
- AddError("workgroup memory cannot be used by " +
- stage_name.str() + " pipeline stage",
- user->source);
- break;
- }
+ for (auto* user : var->Users()) {
+ if (func == user->Stmt()->Function()) {
+ AddError("workgroup memory cannot be used by " +
+ stage_name.str() + " pipeline stage",
+ user->Declaration()->source);
+ break;
}
}
- AddNote("variable is declared here", var->declaration->source);
+ AddNote("variable is declared here", var->Declaration()->source);
if (func != entry_point) {
- TraverseCallChain(entry_point, func, [&](FunctionInfo* f) {
- AddNote("called by function '" +
- builder_->Symbols().NameFor(f->declaration->symbol) +
- "'",
- f->declaration->source);
+ TraverseCallChain(entry_point, func, [&](const sem::Function* f) {
+ AddNote(
+ "called by function '" +
+ builder_->Symbols().NameFor(f->Declaration()->symbol) +
+ "'",
+ f->Declaration()->source);
});
AddNote("called by entry point '" +
builder_->Symbols().NameFor(
- entry_point->declaration->symbol) +
+ entry_point->Declaration()->symbol) +
"'",
- entry_point->declaration->source);
+ entry_point->Declaration()->source);
}
return false;
}
@@ -3658,33 +3638,35 @@
if (!check_workgroup_storage(entry_point, entry_point)) {
return false;
}
- for (auto* func : entry_point->transitive_calls) {
+ for (auto* func : entry_point->TransitivelyCalledFunctions()) {
if (!check_workgroup_storage(func, entry_point)) {
return false;
}
}
}
- auto check_intrinsic_calls = [&](FunctionInfo* func,
- FunctionInfo* entry_point) {
- auto stage = entry_point->declaration->PipelineStage();
- for (auto& call : func->intrinsic_calls) {
- if (!call.intrinsic->SupportedStages().Contains(stage)) {
+ auto check_intrinsic_calls = [&](const sem::Function* func,
+ const sem::Function* entry_point) {
+ auto stage = entry_point->Declaration()->PipelineStage();
+ for (auto* intrinsic : func->DirectlyCalledIntrinsics()) {
+ if (!intrinsic->SupportedStages().Contains(stage)) {
+ auto* call = func->FindDirectCallTo(intrinsic);
std::stringstream err;
err << "built-in cannot be used by " << stage << " pipeline stage";
- AddError(err.str(), call.call->source);
+ AddError(err.str(), call ? call->Declaration()->source
+ : func->Declaration()->source);
if (func != entry_point) {
- TraverseCallChain(entry_point, func, [&](FunctionInfo* f) {
+ TraverseCallChain(entry_point, func, [&](const sem::Function* f) {
AddNote("called by function '" +
- builder_->Symbols().NameFor(f->declaration->symbol) +
+ builder_->Symbols().NameFor(f->Declaration()->symbol) +
"'",
- f->declaration->source);
+ f->Declaration()->source);
});
AddNote("called by entry point '" +
builder_->Symbols().NameFor(
- entry_point->declaration->symbol) +
+ entry_point->Declaration()->symbol) +
"'",
- entry_point->declaration->source);
+ entry_point->Declaration()->source);
}
return false;
}
@@ -3696,7 +3678,7 @@
if (!check_intrinsic_calls(entry_point, entry_point)) {
return false;
}
- for (auto* func : entry_point->transitive_calls) {
+ for (auto* func : entry_point->TransitivelyCalledFunctions()) {
if (!check_intrinsic_calls(func, entry_point)) {
return false;
}
@@ -3706,15 +3688,15 @@
}
template <typename CALLBACK>
-void Resolver::TraverseCallChain(FunctionInfo* from,
- FunctionInfo* to,
+void Resolver::TraverseCallChain(const sem::Function* from,
+ const sem::Function* to,
CALLBACK&& callback) const {
- for (auto* f : from->transitive_calls) {
+ for (auto* f : from->TransitivelyCalledFunctions()) {
if (f == to) {
callback(f);
return;
}
- if (f->transitive_calls.contains(to)) {
+ if (f->TransitivelyCalledFunctions().contains(to)) {
TraverseCallChain(f, to, callback);
callback(f);
return;
@@ -3724,127 +3706,6 @@
<< "TraverseCallChain() 'from' does not transitively call 'to'";
}
-void Resolver::CreateSemanticNodes() const {
- auto& sem = builder_->Sem();
-
- // Collate all the 'ancestor_entry_points' - this is a map of function
- // symbol to all the entry points that transitively call the function.
- std::unordered_map<Symbol, std::vector<Symbol>> ancestor_entry_points;
- for (auto* entry_point : entry_points_) {
- for (auto* call : entry_point->transitive_calls) {
- auto& vec = ancestor_entry_points[call->declaration->symbol];
- vec.emplace_back(entry_point->declaration->symbol);
- }
- }
-
- // Create semantic nodes for all ast::Variables
- std::unordered_map<const tint::ast::Variable*, sem::Parameter*> sem_params;
- for (auto it : variable_to_info_) {
- auto* var = it.first;
- auto* info = it.second;
-
- sem::Variable* sem_var = nullptr;
-
- if (ast::HasDecoration<ast::OverrideDecoration>(var->decorations)) {
- // Create a pipeline overridable constant.
- sem_var = builder_->create<sem::GlobalVariable>(var, info->type,
- info->constant_id);
- } else {
- switch (info->kind) {
- case VariableKind::kGlobal:
- sem_var = builder_->create<sem::GlobalVariable>(
- var, info->type, info->storage_class, info->access,
- info->binding_point);
- break;
- case VariableKind::kLocal:
- sem_var = builder_->create<sem::LocalVariable>(
- var, info->type, info->storage_class, info->access);
- break;
- case VariableKind::kParameter: {
- auto* param = builder_->create<sem::Parameter>(
- var, info->index, info->type, info->storage_class, info->access);
- sem_var = param;
- sem_params.emplace(var, param);
- break;
- }
- }
- }
-
- std::vector<const sem::VariableUser*> users;
- for (auto* user : info->users) {
- // Create semantic node for the identifier expression if necessary
- auto* sem_expr = sem.Get(user);
- if (sem_expr == nullptr) {
- auto& expr_info = expr_info_.at(user);
- auto* type = expr_info.type;
- auto* stmt = expr_info.statement;
- auto* sem_user = builder_->create<sem::VariableUser>(
- user, type, stmt, sem_var, expr_info.constant_value);
- sem_var->AddUser(sem_user);
- sem.Add(user, sem_user);
- } else {
- auto* sem_user = sem_expr->As<sem::VariableUser>();
- if (!sem_user) {
- TINT_ICE(Resolver, diagnostics_) << "expected sem::VariableUser, got "
- << sem_expr->TypeInfo().name;
- }
- sem_var->AddUser(sem_user);
- }
- }
- sem.Add(var, sem_var);
- }
-
- auto remap_vars = [&sem](const std::vector<VariableInfo*>& in) {
- std::vector<const sem::GlobalVariable*> out;
- out.reserve(in.size());
- for (auto* info : in) {
- out.emplace_back(sem.Get<sem::GlobalVariable>(info->declaration));
- }
- return out;
- };
-
- // Create semantic nodes for all ast::Functions
- std::unordered_map<FunctionInfo*, sem::Function*> func_info_to_sem_func;
- for (auto it : function_to_info_) {
- auto* func = it.first;
- auto* info = it.second;
-
- std::vector<sem::Parameter*> parameters;
- parameters.reserve(info->parameters.size());
- for (auto* p : info->parameters) {
- parameters.emplace_back(sem_params.at(p->declaration));
- }
-
- auto* sem_func = builder_->create<sem::Function>(
- info->declaration, info->return_type, parameters,
- remap_vars(info->referenced_module_vars),
- remap_vars(info->local_referenced_module_vars), info->callsites,
- ancestor_entry_points[func->symbol], info->workgroup_size);
- func_info_to_sem_func.emplace(info, sem_func);
- sem.Add(func, sem_func);
- }
-
- // Create semantic nodes for all ast::CallExpressions
- for (auto it : function_calls_) {
- auto* call = it.first;
- auto info = it.second;
- auto* sem_func = func_info_to_sem_func.at(info.function);
- sem.Add(call, builder_->create<sem::Call>(call, sem_func, info.statement));
- }
-
- // Create semantic nodes for all remaining expression types
- for (auto it : expr_info_) {
- auto* expr = it.first;
- auto& info = it.second;
- if (sem.Get(expr)) {
- // Expression has already been assigned a semantic node
- continue;
- }
- sem.Add(expr, builder_->create<sem::Expression>(
- expr, info.type, info.statement, info.constant_value));
- }
-}
-
sem::Array* Resolver::Array(const ast::Array* arr) {
auto source = arr->source;
@@ -3854,7 +3715,7 @@
}
if (!IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize()
- AddError(elem_type->FriendlyName(builder_->Symbols()) +
+ AddError(TypeNameOf(elem_type) +
" cannot be used as an element type of an array",
source);
return nullptr;
@@ -3892,13 +3753,14 @@
// sem::Array uses a size of 0 for a runtime-sized array.
uint32_t count = 0;
if (auto* count_expr = arr->count) {
- if (!Expression(count_expr)) {
+ auto* count_sem = Expression(count_expr);
+ if (!count_sem) {
return nullptr;
}
auto size_source = count_expr->source;
- auto* ty = TypeOf(count_expr)->UnwrapRef();
+ auto* ty = count_sem->Type()->UnwrapRef();
if (!ty->is_integer_scalar()) {
AddError("array size must be integer scalar", size_source);
return nullptr;
@@ -3906,21 +3768,21 @@
if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
// Make sure the identifier is a non-overridable module-scope constant.
- VariableInfo* var = variable_stack_.Get(ident->symbol);
- if (!var || var->kind != VariableKind::kGlobal ||
- !var->declaration->is_const) {
+ auto* var = variable_stack_.Get(ident->symbol);
+ if (!var || !var->Is<sem::GlobalVariable>() ||
+ !var->Declaration()->is_const) {
AddError("array size identifier must be a module-scope constant",
size_source);
return nullptr;
}
if (ast::HasDecoration<ast::OverrideDecoration>(
- var->declaration->decorations)) {
+ var->Declaration()->decorations)) {
AddError("array size expression must not be pipeline-overridable",
size_source);
return nullptr;
}
- count_expr = var->declaration->constructor;
+ count_expr = var->Declaration()->constructor;
} else if (!count_expr->Is<ast::ScalarConstructorExpression>()) {
AddError(
"array size expression must be either a literal or a module-scope "
@@ -3929,7 +3791,7 @@
return nullptr;
}
- auto count_val = ConstantValueOf(count_expr);
+ auto count_val = count_sem->ConstantValue();
if (!count_val) {
TINT_ICE(Resolver, diagnostics_)
<< "could not resolve array size expression";
@@ -4128,7 +3990,7 @@
const Source& source,
const bool is_input) {
std::string inputs_or_output = is_input ? "inputs" : "output";
- if (current_function_ && current_function_->declaration->PipelineStage() ==
+ if (current_function_ && current_function_->Declaration()->PipelineStage() ==
ast::PipelineStage::kCompute) {
AddError("decoration is not valid for compute shader " + inputs_or_output,
location->source);
@@ -4136,7 +3998,7 @@
}
if (!type->is_numeric_scalar_or_vector()) {
- std::string invalid_type = type->FriendlyName(builder_->Symbols());
+ std::string invalid_type = TypeNameOf(type);
AddError("cannot apply 'location' attribute to declaration of type '" +
invalid_type + "'",
source);
@@ -4201,7 +4063,7 @@
// Validate member type
if (!IsPlain(type)) {
- AddError(type->FriendlyName(builder_->Symbols()) +
+ AddError(TypeNameOf(type) +
" cannot be used as the type of a structure member",
member->source);
return nullptr;
@@ -4323,7 +4185,7 @@
}
bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
- auto* func_type = current_function_->return_type;
+ auto* func_type = current_function_->ReturnType();
auto* ret_type = ret->value ? TypeOf(ret->value)->UnwrapRef()
: builder_->create<sem::Void>();
@@ -4332,13 +4194,13 @@
AddError(
"return statement type must match its function "
"return type, returned '" +
- ret_type->FriendlyName(builder_->Symbols()) + "', expected '" +
- current_function_->return_type_name + "'",
+ TypeNameOf(ret_type) + "', expected '" + TypeNameOf(func_type) +
+ "'",
ret->source);
return false;
}
- auto* sem = builder_->Sem().Get(ret);
+ auto* sem = Sem(ret);
if (auto* continuing =
sem->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
AddError("continuing blocks must not contain a return statement",
@@ -4353,8 +4215,6 @@
}
bool Resolver::Return(const ast::ReturnStatement* ret) {
- current_function_->return_statements.push_back(ret);
-
if (auto* value = ret->value) {
if (!Expression(value)) {
return false;
@@ -4435,8 +4295,8 @@
}
bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) {
- auto* sem =
- builder_->create<sem::SwitchStatement>(stmt, current_compound_statement_);
+ auto* sem = builder_->create<sem::SwitchStatement>(
+ stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
if (!Expression(stmt->condition)) {
@@ -4464,15 +4324,15 @@
}
bool Resolver::ValidateAssignment(const ast::AssignmentStatement* a) {
- auto const* rhs_type = TypeOf(a->rhs);
+ auto const* rhs_ty = TypeOf(a->rhs);
if (a->lhs->Is<ast::PhonyExpression>()) {
// https://www.w3.org/TR/WGSL/#phony-assignment-section
- auto* ty = rhs_type->UnwrapRef();
+ auto* ty = rhs_ty->UnwrapRef();
if (!ty->IsConstructible() &&
!ty->IsAnyOf<sem::Pointer, sem::Texture, sem::Sampler>()) {
AddError(
- "cannot assign '" + TypeNameOf(a->rhs) +
+ "cannot assign '" + TypeNameOf(rhs_ty) +
"' to '_'. '_' can only be assigned a constructible, pointer, "
"texture or sampler type",
a->rhs->source);
@@ -4482,52 +4342,53 @@
}
// https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement
- auto const* lhs_type = TypeOf(a->lhs);
+ auto const* lhs_ty = TypeOf(a->lhs);
if (auto* ident = a->lhs->As<ast::IdentifierExpression>()) {
- if (VariableInfo* var = variable_stack_.Get(ident->symbol)) {
- if (var->kind == VariableKind::kParameter) {
+ if (auto* var = variable_stack_.Get(ident->symbol)) {
+ if (var->Is<sem::Parameter>()) {
AddError("cannot assign to function parameter", a->lhs->source);
AddNote("'" + builder_->Symbols().NameFor(ident->symbol) +
"' is declared here:",
- var->declaration->source);
+ var->Declaration()->source);
return false;
}
- if (var->declaration->is_const) {
+ if (var->Declaration()->is_const) {
AddError("cannot assign to const", a->lhs->source);
AddNote("'" + builder_->Symbols().NameFor(ident->symbol) +
"' is declared here:",
- var->declaration->source);
+ var->Declaration()->source);
return false;
}
}
}
- auto* lhs_ref = lhs_type->As<sem::Reference>();
+ auto* lhs_ref = lhs_ty->As<sem::Reference>();
if (!lhs_ref) {
// LHS is not a reference, so it has no storage.
- AddError("cannot assign to value of type '" + TypeNameOf(a->lhs) + "'",
+ AddError("cannot assign to value of type '" + TypeNameOf(lhs_ty) + "'",
a->lhs->source);
return false;
}
- auto* storage_type = lhs_ref->StoreType();
- auto* value_type = rhs_type->UnwrapRef(); // Implicit load of RHS
+ auto* storage_ty = lhs_ref->StoreType();
+ auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
// Value type has to match storage type
- if (storage_type != value_type) {
- AddError("cannot assign '" + TypeNameOf(a->rhs) + "' to '" +
- TypeNameOf(a->lhs) + "'",
+ if (storage_ty != value_type) {
+ AddError("cannot assign '" + TypeNameOf(rhs_ty) + "' to '" +
+ TypeNameOf(lhs_ty) + "'",
a->source);
return false;
}
- if (!storage_type->IsConstructible()) {
+ if (!storage_ty->IsConstructible()) {
AddError("storage type of assignment must be constructible", a->source);
return false;
}
if (lhs_ref->Access() == ast::Access::kRead) {
- AddError("cannot store into a read-only type '" + TypeNameOf(a->lhs) + "'",
- a->source);
+ AddError(
+ "cannot store into a read-only type '" + RawTypeNameOf(lhs_ty) + "'",
+ a->source);
return false;
}
return true;
@@ -4537,11 +4398,11 @@
const Source& source,
bool check_global_scope_only) {
if (check_global_scope_only) {
- if (VariableInfo* var = variable_stack_.Get(sym)) {
- if (var->kind == VariableKind::kGlobal) {
+ if (auto* var = variable_stack_.Get(sym)) {
+ if (var->Is<sem::GlobalVariable>()) {
AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'",
source);
- AddNote("previous definition is here", var->declaration->source);
+ AddNote("previous definition is here", var->Declaration()->source);
return false;
}
}
@@ -4549,14 +4410,14 @@
if (it != symbol_to_function_.end()) {
AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'",
source);
- AddNote("previous definition is here", it->second->declaration->source);
+ AddNote("previous definition is here", it->second->Declaration()->source);
return false;
}
} else {
- if (VariableInfo* var = variable_stack_.Get(sym)) {
+ if (auto* var = variable_stack_.Get(sym)) {
AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'",
source);
- AddNote("previous definition is here", var->declaration->source);
+ AddNote("previous definition is here", var->Declaration()->source);
return false;
}
}
@@ -4592,8 +4453,7 @@
for (auto* member : str->Members()) {
if (!ApplyStorageClassUsageToType(sc, member->Type(), usage)) {
std::stringstream err;
- err << "while analysing structure member "
- << str->FriendlyName(builder_->Symbols()) << "."
+ err << "while analysing structure member " << TypeNameOf(str) << "."
<< builder_->Symbols().NameFor(member->Declaration()->symbol);
AddNote(err.str(), member->Declaration()->source);
return false;
@@ -4609,9 +4469,8 @@
if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) {
std::stringstream err;
- err << "Type '" << ty->FriendlyName(builder_->Symbols())
- << "' cannot be used in storage class '" << sc
- << "' as it is non-host-shareable";
+ err << "Type '" << TypeNameOf(ty) << "' cannot be used in storage class '"
+ << sc << "' as it is non-host-shareable";
AddError(err.str(), usage);
return false;
}
@@ -4630,10 +4489,10 @@
variable_stack_.Push();
TINT_DEFER({
- TINT_DEFER(variable_stack_.Pop());
current_block_ = prev_current_block;
current_compound_statement_ = prev_current_compound_statement;
current_statement_ = prev_current_statement;
+ variable_stack_.Pop();
});
return callback();
@@ -4671,26 +4530,18 @@
diagnostics_.add_note(diag::System::Resolver, msg, source);
}
-Resolver::VariableInfo::VariableInfo(const ast::Variable* decl,
- sem::Type* ty,
- const std::string& tn,
- ast::StorageClass sc,
- ast::Access ac,
- VariableKind k,
- uint32_t idx)
- : declaration(decl),
- type(ty),
- type_name(tn),
- storage_class(sc),
- access(ac),
- kind(k),
- index(idx) {}
-
-Resolver::VariableInfo::~VariableInfo() = default;
-
-Resolver::FunctionInfo::FunctionInfo(const ast::Function* decl)
- : declaration(decl) {}
-Resolver::FunctionInfo::~FunctionInfo() = default;
+template <typename SEM, typename AST_OR_TYPE>
+const sem::Info::GetResultType<SEM, AST_OR_TYPE>* Resolver::Sem(
+ const AST_OR_TYPE* ast) {
+ auto* sem = builder_->Sem().Get<SEM>(ast);
+ if (!sem) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "AST node '" << ast->TypeInfo().name << "' had no semantic info\n"
+ << "At: " << ast->source << "\n"
+ << "Pointer: " << ast;
+ }
+ return sem;
+}
} // namespace resolver
} // namespace tint
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index edd8759..5730a55 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -95,79 +95,9 @@
/// Describes the context in which a variable is declared
enum class VariableKind { kParameter, kLocal, kGlobal };
- /// Structure holding semantic information about a variable.
- /// Used to build the sem::Variable nodes at the end of resolving.
- struct VariableInfo {
- VariableInfo(const ast::Variable* decl,
- sem::Type* type,
- const std::string& type_name,
- ast::StorageClass storage_class,
- ast::Access ac,
- VariableKind k,
- uint32_t idx);
- ~VariableInfo();
-
- ast::Variable const* const declaration;
- sem::Type* type;
- std::string const type_name;
- ast::StorageClass storage_class;
- ast::Access const access;
- std::vector<const ast::IdentifierExpression*> users;
- sem::BindingPoint binding_point;
- VariableKind kind;
- uint32_t index = 0; // Parameter index, if kind == kParameter
- uint16_t constant_id = 0;
- };
-
- struct IntrinsicCallInfo {
- const ast::CallExpression* call;
- const sem::Intrinsic* intrinsic;
- };
-
std::set<std::pair<const sem::Struct*, ast::StorageClass>>
valid_struct_storage_layouts_;
- /// Structure holding semantic information about a function.
- /// Used to build the sem::Function nodes at the end of resolving.
- struct FunctionInfo {
- explicit FunctionInfo(const ast::Function* decl);
- ~FunctionInfo();
-
- const ast::Function* const declaration;
- std::vector<VariableInfo*> parameters;
- utils::UniqueVector<VariableInfo*> referenced_module_vars;
- utils::UniqueVector<VariableInfo*> local_referenced_module_vars;
- std::vector<const ast::ReturnStatement*> return_statements;
- std::vector<const ast::CallExpression*> callsites;
- sem::Type* return_type = nullptr;
- std::string return_type_name;
- std::array<sem::WorkgroupDimension, 3> workgroup_size;
- std::vector<IntrinsicCallInfo> intrinsic_calls;
-
- // List of transitive calls this function makes
- utils::UniqueVector<FunctionInfo*> transitive_calls;
-
- // List of entry point functions that transitively call this function
- utils::UniqueVector<FunctionInfo*> ancestor_entry_points;
- };
-
- /// Structure holding semantic information about an expression.
- /// Used to build the sem::Expression nodes at the end of resolving.
- struct ExpressionInfo {
- sem::Type const* type;
- std::string const type_name; // Declared type name
- sem::Statement* statement;
- sem::Constant constant_value;
- };
-
- /// Structure holding semantic information about a call expression to an
- /// ast::Function.
- /// Used to build the sem::Call nodes at the end of resolving.
- struct FunctionCallInfo {
- FunctionInfo* function;
- sem::Statement* statement;
- };
-
/// Structure holding semantic information about a block (i.e. scope), such as
/// parent block and variables declared in the block.
/// Used to validate variable scoping rules.
@@ -231,35 +161,40 @@
const ast::ExpressionList& params,
uint32_t* id);
- void set_referenced_from_function_if_needed(VariableInfo* var, bool local);
-
+ //////////////////////////////////////////////////////////////////////////////
// AST and Type traversal methods
+ //////////////////////////////////////////////////////////////////////////////
+
+ // Expression resolving methods
+ // Returns the semantic node pointer on success, nullptr on failure.
+ sem::Expression* ArrayAccessor(const ast::ArrayAccessorExpression*);
+ sem::Expression* Binary(const ast::BinaryExpression*);
+ sem::Expression* Bitcast(const ast::BitcastExpression*);
+ sem::Expression* Call(const ast::CallExpression*);
+ sem::Expression* Constructor(const ast::ConstructorExpression*);
+ sem::Expression* Expression(const ast::Expression*);
+ sem::Function* Function(const ast::Function*);
+ sem::Call* FunctionCall(const ast::CallExpression*);
+ sem::Expression* Identifier(const ast::IdentifierExpression*);
+ sem::Call* IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType);
+ sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*);
+ sem::Expression* UnaryOp(const ast::UnaryOpExpression*);
+
+ // Statement resolving methods
// Each return true on success, false on failure.
- bool ArrayAccessor(const ast::ArrayAccessorExpression*);
bool Assignment(const ast::AssignmentStatement* a);
- bool Binary(const ast::BinaryExpression*);
- bool Bitcast(const ast::BitcastExpression*);
bool BlockStatement(const ast::BlockStatement*);
- bool Call(const ast::CallExpression*);
bool CaseStatement(const ast::CaseStatement*);
- bool Constructor(const ast::ConstructorExpression*);
bool ElseStatement(const ast::ElseStatement*);
- bool Expression(const ast::Expression*);
bool ForLoopStatement(const ast::ForLoopStatement*);
- bool Function(const ast::Function*);
- bool FunctionCall(const ast::CallExpression* call);
- bool GlobalVariable(const ast::Variable* var);
- bool Identifier(const ast::IdentifierExpression*);
- bool IfStatement(const ast::IfStatement*);
- bool IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType);
- bool LoopStatement(const ast::LoopStatement*);
- bool MemberAccessor(const ast::MemberAccessorExpression*);
bool Parameter(const ast::Variable* param);
+ bool GlobalVariable(const ast::Variable* var);
+ bool IfStatement(const ast::IfStatement*);
+ bool LoopStatement(const ast::LoopStatement*);
bool Return(const ast::ReturnStatement* ret);
bool Statement(const ast::Statement*);
bool Statements(const ast::StatementList&);
bool SwitchStatement(const ast::SwitchStatement* s);
- bool UnaryOp(const ast::UnaryOpExpression*);
bool VariableDeclStatement(const ast::VariableDeclStatement*);
// AST and Type validation methods
@@ -270,18 +205,16 @@
uint32_t el_align,
const Source& source);
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
- bool ValidateAtomicVariable(const VariableInfo* info);
+ bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,
const bool is_input);
- bool ValidateCall(const ast::CallExpression* call);
- bool ValidateCallStatement(const ast::CallStatement* stmt);
- bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
- bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
- bool ValidateFunctionCall(const ast::CallExpression* call,
- const FunctionInfo* target);
- bool ValidateGlobalVariable(const VariableInfo* var);
+ bool ValidateCall(const sem::Call* call);
+ bool ValidateEntryPoint(const sem::Function* func);
+ bool ValidateFunction(const sem::Function* func);
+ bool ValidateFunctionCall(const sem::Call* call);
+ bool ValidateGlobalVariable(const sem::Variable* var);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type);
bool ValidateLocationDecoration(const ast::LocationDecoration* location,
@@ -291,11 +224,11 @@
const bool is_input = false);
bool ValidateMatrix(const sem::Matrix* ty, const Source& source);
bool ValidateFunctionParameter(const ast::Function* func,
- const VariableInfo* info);
+ const sem::Variable* var);
bool ValidateNoDuplicateDefinition(Symbol sym,
const Source& source,
bool check_global_scope_only = false);
- bool ValidateParameter(const ast::Function* func, const VariableInfo* info);
+ bool ValidateParameter(const ast::Function* func, const sem::Variable* var);
bool ValidateReturn(const ast::ReturnStatement* ret);
bool ValidateStatements(const ast::StatementList& stmts);
bool ValidateStorageTexture(const ast::StorageTexture* t);
@@ -303,33 +236,30 @@
bool ValidateStructureConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Struct* struct_type);
bool ValidateSwitch(const ast::SwitchStatement* s);
- bool ValidateVariable(const VariableInfo* info);
+ bool ValidateVariable(const sem::Variable* var);
bool ValidateVariableConstructor(const ast::Variable* var,
ast::StorageClass storage_class,
const sem::Type* storage_type,
- const std::string& type_name,
- const sem::Type* rhs_type,
- const std::string& rhs_type_name);
+ const sem::Type* rhs_type);
bool ValidateVector(const sem::Vector* ty, const Source& source);
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
- const sem::Vector* vec_type,
- const std::string& type_name);
+ const sem::Vector* vec_type);
bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
- const sem::Matrix* matrix_type,
- const std::string& type_name);
+ const sem::Matrix* matrix_type);
bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor,
- const sem::Type* type,
- const std::string& type_name);
+ const sem::Type* type);
bool ValidateArrayConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Array* arr_type);
bool ValidateTypeDecl(const ast::TypeDecl* named_type) const;
- bool ValidateTextureIntrinsicFunction(const ast::CallExpression* ast_call,
- const sem::Call* sem_call);
+ bool ValidateTextureIntrinsicFunction(const sem::Call* call);
bool ValidateNoDuplicateDecorations(const ast::DecorationList& decorations);
// sem::Struct is assumed to have at least one member
bool ValidateStorageClassLayout(const sem::Struct* type,
ast::StorageClass sc);
- bool ValidateStorageClassLayout(const VariableInfo* info);
+ bool ValidateStorageClassLayout(const sem::Variable* var);
+
+ /// Resolves the WorkgroupSize for the given function
+ bool WorkgroupSizeFor(const ast::Function*, sem::WorkgroupSize& ws);
/// @returns the sem::Type for the ast::Type `ty`, building it if it
/// hasn't been constructed already. If an error is raised, nullptr is
@@ -355,16 +285,16 @@
/// raised. raised, nullptr is returned.
sem::Struct* Structure(const ast::Struct* str);
- /// @returns the VariableInfo for the variable `var`, building it if it hasn't
- /// been constructed already. If an error is raised, nullptr is returned.
+ /// @returns the semantic info for the variable `var`. If an error is raised,
+ /// nullptr is returned.
/// @note this method does not resolve the decorations as these are
/// context-dependent (global, local, parameter)
/// @param var the variable to create or return the `VariableInfo` for
/// @param kind what kind of variable we are declaring
/// @param index the index of the parameter, if this variable is a parameter
- VariableInfo* Variable(const ast::Variable* var,
- VariableKind kind,
- uint32_t index = 0);
+ sem::Variable* Variable(const ast::Variable* var,
+ VariableKind kind,
+ uint32_t index = 0);
/// Records the storage class usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the
@@ -389,23 +319,17 @@
/// @param expr the expression
sem::Type* TypeOf(const ast::Expression* expr);
- /// @returns the declared type name of the ast::Expression `expr`
- /// @param expr the type name
- std::string TypeNameOf(const ast::Expression* expr);
+ /// @returns the type name of the given semantic type, unwrapping references.
+ std::string TypeNameOf(const sem::Type* ty);
+
+ /// @returns the type name of the given semantic type, without unwrapping
+ /// references.
+ std::string RawTypeNameOf(const sem::Type* ty);
/// @returns the semantic type of the AST literal `lit`
/// @param lit the literal
sem::Type* TypeOf(const ast::Literal* lit);
- /// Records the semantic information for the expression node with the resolved
- /// type `type` and optional declared type name `type_name`.
- /// @param expr the expression
- /// @param type the resolved type
- /// @param type_name the declared type name
- void SetExprInfo(const ast::Expression* expr,
- const sem::Type* type,
- std::string type_name = "");
-
/// Assigns `stmt` to #current_statement_, #current_compound_statement_, and
/// possibly #current_block_, pushes the variable scope, then calls
/// `callback`. Before returning #current_statement_,
@@ -437,16 +361,13 @@
void AddNote(const std::string& msg, const Source& source) const;
template <typename CALLBACK>
- void TraverseCallChain(FunctionInfo* from,
- FunctionInfo* to,
+ void TraverseCallChain(const sem::Function* from,
+ const sem::Function* to,
CALLBACK&& callback) const;
//////////////////////////////////////////////////////////////////////////////
/// Constant value evaluation methods
//////////////////////////////////////////////////////////////////////////////
- /// @return the Constant value of the given Expression
- sem::Constant ConstantValueOf(const ast::Expression* expr);
-
/// Cast `Value` to `target_type`
/// @return the casted value
sem::Constant ConstantCast(const sem::Constant& value,
@@ -461,29 +382,27 @@
const ast::TypeConstructorExpression* type_ctor,
const sem::Type* type);
+ /// Sem is a helper for obtaining the semantic node for the given AST node.
+ template <typename SEM = sem::Info::InferFromAST,
+ typename AST_OR_TYPE = CastableBase>
+ const sem::Info::GetResultType<SEM, AST_OR_TYPE>* Sem(const AST_OR_TYPE* ast);
+
ProgramBuilder* const builder_;
diag::List& diagnostics_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
- ScopeStack<VariableInfo*> variable_stack_;
- std::unordered_map<Symbol, FunctionInfo*> symbol_to_function_;
- std::vector<FunctionInfo*> entry_points_;
+ ScopeStack<sem::Variable*> variable_stack_;
+ std::unordered_map<Symbol, sem::Function*> symbol_to_function_;
+ std::vector<sem::Function*> entry_points_;
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
- std::unordered_map<const ast::Function*, FunctionInfo*> function_to_info_;
- std::unordered_map<const ast::Variable*, VariableInfo*> variable_to_info_;
- std::unordered_map<const ast::CallExpression*, FunctionCallInfo>
- function_calls_;
- std::unordered_map<const ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<Symbol, TypeDeclInfo> named_type_info_;
std::unordered_set<const ast::Node*> marked_;
- std::unordered_map<uint32_t, const VariableInfo*> constant_ids_;
+ std::unordered_map<uint32_t, const sem::Variable*> constant_ids_;
- FunctionInfo* current_function_ = nullptr;
+ sem::Function* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;
sem::CompoundStatement* current_compound_statement_ = nullptr;
sem::BlockStatement* current_block_ = nullptr;
- BlockAllocator<VariableInfo> variable_infos_;
- BlockAllocator<FunctionInfo> function_infos_;
};
} // namespace resolver
diff --git a/src/resolver/resolver_constants.cc b/src/resolver/resolver_constants.cc
index e28ed03..fb59ff3 100644
--- a/src/resolver/resolver_constants.cc
+++ b/src/resolver/resolver_constants.cc
@@ -15,6 +15,7 @@
#include "src/resolver/resolver.h"
#include "src/sem/constant.h"
+#include "src/utils/get_or_create.h"
namespace tint {
namespace resolver {
@@ -26,46 +27,6 @@
} // namespace
-sem::Constant Resolver::ConstantCast(const sem::Constant& value,
- const sem::Type* target_elem_type) {
- if (value.ElementType() == target_elem_type) {
- return value;
- }
-
- sem::Constant::Scalars elems;
- for (size_t i = 0; i < value.Elements().size(); ++i) {
- if (target_elem_type->Is<sem::I32>()) {
- elems.emplace_back(
- value.WithScalarAt(i, [](auto&& s) { return static_cast<i32>(s); }));
- } else if (target_elem_type->Is<sem::U32>()) {
- elems.emplace_back(
- value.WithScalarAt(i, [](auto&& s) { return static_cast<u32>(s); }));
- } else if (target_elem_type->Is<sem::F32>()) {
- elems.emplace_back(
- value.WithScalarAt(i, [](auto&& s) { return static_cast<f32>(s); }));
- } else if (target_elem_type->Is<sem::Bool>()) {
- elems.emplace_back(
- value.WithScalarAt(i, [](auto&& s) { return static_cast<bool>(s); }));
- }
- }
-
- auto* target_type =
- value.Type()->Is<sem::Vector>()
- ? builder_->create<sem::Vector>(target_elem_type,
- static_cast<uint32_t>(elems.size()))
- : target_elem_type;
-
- return sem::Constant(target_type, elems);
-}
-
-sem::Constant Resolver::ConstantValueOf(const ast::Expression* expr) {
- auto it = expr_info_.find(expr);
- if (it != expr_info_.end()) {
- return it->second.constant_value;
- }
- return {};
-}
-
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr,
const sem::Type* type) {
if (auto* e = expr->As<ast::ScalarConstructorExpression>()) {
@@ -131,11 +92,11 @@
// type_ctor's type.
sem::Constant::Scalars elems;
for (auto* cv : ctor_values) {
- auto value = ConstantValueOf(cv);
- if (!value.IsValid()) {
+ auto* expr = builder_->Sem().Get(cv);
+ if (!expr || !expr->ConstantValue()) {
return {};
}
- auto cast = ConstantCast(value, elem_type);
+ auto cast = ConstantCast(expr->ConstantValue(), elem_type);
elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end());
}
@@ -149,5 +110,37 @@
return sem::Constant(type, std::move(elems));
}
+sem::Constant Resolver::ConstantCast(const sem::Constant& value,
+ const sem::Type* target_elem_type) {
+ if (value.ElementType() == target_elem_type) {
+ return value;
+ }
+
+ sem::Constant::Scalars elems;
+ for (size_t i = 0; i < value.Elements().size(); ++i) {
+ if (target_elem_type->Is<sem::I32>()) {
+ elems.emplace_back(
+ value.WithScalarAt(i, [](auto&& s) { return static_cast<i32>(s); }));
+ } else if (target_elem_type->Is<sem::U32>()) {
+ elems.emplace_back(
+ value.WithScalarAt(i, [](auto&& s) { return static_cast<u32>(s); }));
+ } else if (target_elem_type->Is<sem::F32>()) {
+ elems.emplace_back(
+ value.WithScalarAt(i, [](auto&& s) { return static_cast<f32>(s); }));
+ } else if (target_elem_type->Is<sem::Bool>()) {
+ elems.emplace_back(
+ value.WithScalarAt(i, [](auto&& s) { return static_cast<bool>(s); }));
+ }
+ }
+
+ auto* target_type =
+ value.Type()->Is<sem::Vector>()
+ ? builder_->create<sem::Vector>(target_elem_type,
+ static_cast<uint32_t>(elems.size()))
+ : target_elem_type;
+
+ return sem::Constant(target_type, elems);
+}
+
} // namespace resolver
} // namespace tint
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index ddf4d36..99f206a 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -922,8 +922,8 @@
auto* foo_sem = Sem().Get(foo);
ASSERT_NE(foo_sem, nullptr);
ASSERT_EQ(foo_sem->CallSites().size(), 2u);
- EXPECT_EQ(foo_sem->CallSites()[0], call_1);
- EXPECT_EQ(foo_sem->CallSites()[1], call_2);
+ EXPECT_EQ(foo_sem->CallSites()[0]->Declaration(), call_1);
+ EXPECT_EQ(foo_sem->CallSites()[1]->Declaration(), call_2);
auto* bar_sem = Sem().Get(bar);
ASSERT_NE(bar_sem, nullptr);
@@ -1908,17 +1908,17 @@
const auto& b_eps = func_b_sem->AncestorEntryPoints();
ASSERT_EQ(2u, b_eps.size());
- EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]);
- EXPECT_EQ(Symbols().Register("ep_2"), b_eps[1]);
+ EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]->Declaration()->symbol);
+ EXPECT_EQ(Symbols().Register("ep_2"), b_eps[1]->Declaration()->symbol);
const auto& a_eps = func_a_sem->AncestorEntryPoints();
ASSERT_EQ(1u, a_eps.size());
- EXPECT_EQ(Symbols().Register("ep_1"), a_eps[0]);
+ EXPECT_EQ(Symbols().Register("ep_1"), a_eps[0]->Declaration()->symbol);
const auto& c_eps = func_c_sem->AncestorEntryPoints();
ASSERT_EQ(2u, c_eps.size());
- EXPECT_EQ(Symbols().Register("ep_1"), c_eps[0]);
- EXPECT_EQ(Symbols().Register("ep_2"), c_eps[1]);
+ EXPECT_EQ(Symbols().Register("ep_1"), c_eps[0]->Declaration()->symbol);
+ EXPECT_EQ(Symbols().Register("ep_2"), c_eps[1]->Declaration()->symbol);
EXPECT_TRUE(ep_1_sem->AncestorEntryPoints().empty());
EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty());
diff --git a/src/resolver/storage_class_layout_validation_test.cc b/src/resolver/storage_class_layout_validation_test.cc
index 5339f13..3ad7d3f 100644
--- a/src/resolver/storage_class_layout_validation_test.cc
+++ b/src/resolver/storage_class_layout_validation_test.cc
@@ -179,11 +179,11 @@
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
- R"(56:78 error: the offset of a struct member of type 'Inner' in storage class 'uniform' must be a multiple of 16 bytes, but 'inner' is currently at offset 4. Consider setting [[align(16)]] on this member
+ R"(56:78 error: the offset of a struct member of type '[[stride(16)]] array<f32, 10>' in storage class 'uniform' must be a multiple of 16 bytes, but 'inner' is currently at offset 4. Consider setting [[align(16)]] on this member
12:34 note: see layout of struct:
/* align(4) size(164) */ struct Outer {
/* offset( 0) align(4) size( 4) */ scalar : f32;
-/* offset( 4) align(4) size(160) */ inner : Inner;
+/* offset( 4) align(4) size(160) */ inner : [[stride(16)]] array<f32, 10>;
/* */ };
78:90 note: see declaration of variable)");
}
@@ -351,7 +351,7 @@
R"(34:56 error: uniform storage requires that array elements be aligned to 16 bytes, but array stride of 'inner' is currently 8. Consider setting [[stride(16)]] on the array type
12:34 note: see layout of struct:
/* align(4) size(84) */ struct Outer {
-/* offset( 0) align(4) size(80) */ inner : Inner;
+/* offset( 0) align(4) size(80) */ inner : [[stride(8)]] array<f32, 10>;
/* offset(80) align(4) size( 4) */ scalar : i32;
/* */ };
78:90 note: see declaration of variable)");
diff --git a/src/resolver/storage_class_validation_test.cc b/src/resolver/storage_class_validation_test.cc
index 9a60ac4..f902a9e 100644
--- a/src/resolver/storage_class_validation_test.cc
+++ b/src/resolver/storage_class_validation_test.cc
@@ -96,7 +96,8 @@
EXPECT_EQ(
r()->error(),
- R"(56:78 error: variables declared in the <storage> storage class must be of a structure type)");
+ R"(56:78 error: Type 'bool' cannot be used in storage class 'storage' as it is non-host-shareable
+56:78 note: while instantiating variable g)");
}
TEST_F(ResolverStorageClassValidationTest, NotStorage_AccessMode) {
@@ -194,7 +195,8 @@
EXPECT_EQ(
r()->error(),
- R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+ R"(56:78 error: Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable
+56:78 note: while instantiating variable g)");
}
TEST_F(ResolverStorageClassValidationTest, UniformBufferPointer) {
@@ -243,7 +245,8 @@
EXPECT_EQ(
r()->error(),
- R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+ R"(56:78 error: Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable
+56:78 note: while instantiating variable g)");
}
TEST_F(ResolverStorageClassValidationTest, UniformBufferNoBlockDecoration) {
diff --git a/src/resolver/type_constructor_validation_test.cc b/src/resolver/type_constructor_validation_test.cc
index d72cd82..c369f0b 100644
--- a/src/resolver/type_constructor_validation_test.cc
+++ b/src/resolver/type_constructor_validation_test.cc
@@ -1553,7 +1553,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: type in vector constructor does not match vector "
- "type: expected 'f32', found 'UnsignedInt'");
+ "type: expected 'f32', found 'u32'");
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -1638,10 +1638,9 @@
uint32_t columns;
};
-static std::string MatrixStr(const MatrixDimensions& dimensions,
- std::string subtype = "f32") {
+static std::string MatrixStr(const MatrixDimensions& dimensions) {
return "mat" + std::to_string(dimensions.columns) + "x" +
- std::to_string(dimensions.rows) + "<" + subtype + ">";
+ std::to_string(dimensions.rows) + "<f32>";
}
using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
@@ -1919,9 +1918,9 @@
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
- EXPECT_THAT(r()->error(), HasSubstr("12:1 error: invalid constructor for " +
- MatrixStr(param, "Float32") +
- "\n\n3 candidates available:"));
+ EXPECT_THAT(r()->error(),
+ HasSubstr("12:1 error: invalid constructor for " +
+ MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 55a9c1f..bdad5fe 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -499,9 +499,10 @@
Func("func", {p}, ty.f32(), {Decl(x), Return(x)});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(),
- "error: invalid member accessor expression. Expected vector or "
- "struct, got 'ptr<function, vec4<f32>>'");
+ EXPECT_EQ(
+ r()->error(),
+ "error: invalid member accessor expression. "
+ "Expected vector or struct, got 'ptr<function, vec4<f32>, read_write>'");
}
TEST_F(ResolverValidationTest,
diff --git a/src/resolver/var_let_validation_test.cc b/src/resolver/var_let_validation_test.cc
index 762ac00..48469ad 100644
--- a/src/resolver/var_let_validation_test.cc
+++ b/src/resolver/var_let_validation_test.cc
@@ -120,7 +120,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
- R"(3:3 error: cannot initialize let of type 'I32' with value of type 'u32')");
+ R"(3:3 error: cannot initialize let of type 'i32' with value of type 'u32')");
}
TEST_F(ResolverVarLetValidationTest, VarConstructorWrongTypeViaAlias) {
@@ -131,7 +131,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
- R"(3:3 error: cannot initialize var of type 'I32' with value of type 'u32')");
+ R"(3:3 error: cannot initialize var of type 'i32' with value of type 'u32')");
}
TEST_F(ResolverVarLetValidationTest, LetOfPtrConstructedWithRef) {
@@ -147,7 +147,7 @@
EXPECT_EQ(
r()->error(),
- R"(12:34 error: cannot initialize let of type 'ptr<function, f32>' with value of type 'f32')");
+ R"(12:34 error: cannot initialize let of type 'ptr<function, f32, read_write>' with value of type 'f32')");
}
TEST_F(ResolverVarLetValidationTest, LocalVarRedeclared) {