Resolver: Move common logic into Variable()
Variable() is called for globals, locals and parameters. Much of the logic is the same.
Move all the common logic down into Variable(). This:
* Removes some yucky default parameters
* Adds type validation that was missing for globals (broken tests fixed)
* Gives me a single place to implement the Reference type wrapping
Bug: tint:727
Change-Id: I70f4a3603d7fa781da938508aa2a1bc80ec15d77
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50580
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 4c6eb94..8095bea 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -413,20 +413,57 @@
return s;
}
-Resolver::VariableInfo* Resolver::Variable(
- ast::Variable* var,
- const sem::Type* type, /* = nullptr */
- std::string type_name /* = "" */) {
- auto it = variable_to_info_.find(var);
- if (it != variable_to_info_.end()) {
- return it->second;
+Resolver::VariableInfo* Resolver::Variable(ast::Variable* var,
+ bool is_parameter) {
+ if (variable_to_info_.count(var)) {
+ TINT_ICE(diagnostics_) << "Variable "
+ << builder_->Symbols().NameFor(var->symbol())
+ << " already resolved";
+ return nullptr;
}
- if (type == nullptr && var->type()) {
- type = Type(var->type());
- type_name = var->type()->FriendlyName(builder_->Symbols());
+ // If the variable has a declared type, resolve it.
+ std::string type_name;
+ const sem::Type* type = nullptr;
+ if (auto* ty = var->type()) {
+ type_name = ty->FriendlyName(builder_->Symbols());
+ type = Type(ty);
+ if (!type) {
+ return nullptr;
+ }
}
- if (type == nullptr) {
+
+ // Does the variable have a constructor?
+ if (auto* ctor = var->constructor()) {
+ Mark(var->constructor());
+ if (!Expression(var->constructor())) {
+ return nullptr;
+ }
+
+ // Fetch the constructor's type
+ auto* rhs_type = TypeOf(ctor);
+ if (!rhs_type) {
+ return nullptr;
+ }
+
+ // If the variable has no declared type, infer it from the RHS
+ if (type == nullptr) {
+ type_name = TypeNameOf(ctor);
+ type = rhs_type->UnwrapPtr();
+ }
+
+ if (!IsValidAssignment(type, rhs_type)) {
+ diagnostics_.add_error(
+ "variable of type '" + type_name +
+ "' cannot be initialized with a value of type '" +
+ TypeNameOf(ctor) + "'",
+ var->source());
+ return nullptr;
+ }
+ } else if (var->is_const() && !is_parameter &&
+ !ast::HasDecoration<ast::OverrideDecoration>(var->decorations())) {
+ diagnostics_.add_error("let declarations must have initializers",
+ var->source());
return nullptr;
}
@@ -446,7 +483,7 @@
return false;
}
- auto* info = Variable(var);
+ auto* info = Variable(var, /* is_parameter */ false);
if (!info) {
return false;
}
@@ -472,20 +509,6 @@
info->binding_point = {bp.group->value(), bp.binding->value()};
}
- if (var->has_constructor()) {
- Mark(var->constructor());
- if (!Expression(var->constructor())) {
- return false;
- }
- } else {
- if (var->is_const() &&
- !ast::HasDecoration<ast::OverrideDecoration>(var->decorations())) {
- diagnostics_.add_error("let declarations must have initializers",
- var->source());
- return false;
- }
- }
-
if (!ValidateGlobalVariable(info)) {
return false;
}
@@ -1020,7 +1043,7 @@
variable_stack_.push_scope();
for (auto* param : func->params()) {
Mark(param);
- auto* param_info = Variable(param);
+ auto* param_info = Variable(param, /* is_parameter */ true);
if (!param_info) {
return false;
}
@@ -2072,17 +2095,6 @@
ast::Variable* var = stmt->variable();
Mark(var);
- // If the variable has a declared type, resolve it.
- std::string type_name;
- const sem::Type* type = nullptr;
- if (auto* ast_ty = var->type()) {
- type_name = ast_ty->FriendlyName(builder_->Symbols());
- type = Type(ast_ty);
- if (!type) {
- return false;
- }
- }
-
bool is_global = false;
if (variable_stack_.get(var->symbol(), nullptr, &is_global)) {
const char* error_code = is_global ? "v-0013" : "v-0014";
@@ -2093,33 +2105,9 @@
return false;
}
- if (auto* ctor = stmt->variable()->constructor()) {
- Mark(ctor);
- if (!Expression(ctor)) {
- return false;
- }
- auto* rhs_type = TypeOf(ctor);
-
- // If the variable has no type, infer it from the rhs
- if (type == nullptr) {
- type_name = TypeNameOf(ctor);
- type = rhs_type->UnwrapPtr();
- }
-
- if (!IsValidAssignment(type, rhs_type)) {
- diagnostics_.add_error(
- "variable of type '" + type_name +
- "' cannot be initialized with a value of type '" +
- TypeNameOf(ctor) + "'",
- stmt->source());
- return false;
- }
- } else {
- if (stmt->variable()->is_const()) {
- diagnostics_.add_error("let declarations must have initializers",
- var->source());
- return false;
- }
+ auto* info = Variable(var, /* is_parameter */ false);
+ if (!info) {
+ return false;
}
for (auto* deco : var->decorations()) {
@@ -2127,13 +2115,6 @@
Mark(deco);
}
- auto* info = Variable(var, type, type_name);
- if (!info) {
- return false;
- }
- // TODO(bclayton): Remove this and fix tests. We're overriding the semantic
- // type stored in info->type here with a possibly non-canonicalized type.
- info->type = const_cast<sem::Type*>(type);
variable_stack_.set(var->symbol(), info);
current_block_->decls.push_back(var);
@@ -2212,8 +2193,7 @@
TINT_ASSERT(type.sem);
}
if (expr_info_.count(expr)) {
- TINT_ICE(builder_->Diagnostics())
- << "SetType() called twice for the same expression";
+ TINT_ICE(diagnostics_) << "SetType() called twice for the same expression";
}
expr_info_.emplace(expr, ExpressionInfo{type, type_name, current_statement_});
}
@@ -2260,9 +2240,8 @@
} else {
auto* sem_user = sem_expr->As<sem::VariableUser>();
if (!sem_user) {
- TINT_ICE(builder_->Diagnostics())
- << "expected sem::VariableUser, got "
- << sem_expr->TypeInfo().name;
+ TINT_ICE(diagnostics_) << "expected sem::VariableUser, got "
+ << sem_expr->TypeInfo().name;
}
sem_var->AddUser(sem_user);
}
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 02dce4f..95a7365 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -262,13 +262,11 @@
/// @returns the VariableInfo for the variable `var`, building it if it hasn't
/// been constructed already. If an error is raised, nullptr is returned.
+ /// @note this method does not resolve the decorations as these are
+ /// context-dependent (global, local, parameter)
/// @param var the variable to create or return the `VariableInfo` for
- /// @param type optional type of `var` to use instead of `var->type()`.
- /// @param type_name optional type name of `var` to use instead of
- /// `var->type()->FriendlyName()`.
- VariableInfo* Variable(ast::Variable* var,
- const sem::Type* type = nullptr,
- std::string type_name = "");
+ /// @param is_parameter true if the variable represents a parameter
+ VariableInfo* Variable(ast::Variable* var, bool is_parameter);
/// Records the storage class usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the
diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc
index 82a231e..a97d677 100644
--- a/src/resolver/type_validation_test.cc
+++ b/src/resolver/type_validation_test.cc
@@ -90,8 +90,8 @@
// const<in> global_var: f32;
AST().AddGlobalVariable(
create<ast::Variable>(Source{{12, 34}}, Symbols().Register("global_var"),
- ast::StorageClass::kInput, ty.f32(), true, nullptr,
- ast::DecorationList{}));
+ ast::StorageClass::kInput, ty.f32(), true,
+ Expr(1.23f), ast::DecorationList{}));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -113,7 +113,7 @@
Global("global_var0", ty.f32(), ast::StorageClass::kPrivate, Expr(0.1f));
Global(Source{{12, 34}}, "global_var1", ty.f32(), ast::StorageClass::kPrivate,
- Expr(0));
+ Expr(1.0f));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 53c087c..87a9608 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -164,10 +164,8 @@
TEST_F(ResolverValidationTest,
Stmt_VariableDecl_MismatchedTypeScalarConstructor) {
u32 unsigned_value = 2u; // Type does not match variable type
- auto* var =
- Var("my_var", ty.i32(), ast::StorageClass::kNone, Expr(unsigned_value));
-
- auto* decl = Decl(Source{{{3, 3}, {3, 22}}}, var);
+ auto* decl = Decl(Var(Source{{3, 3}}, "my_var", ty.i32(),
+ ast::StorageClass::kNone, Expr(unsigned_value)));
WrapInFunction(decl);
EXPECT_FALSE(r()->Resolve());
@@ -181,10 +179,8 @@
auto* my_int = ty.alias("MyInt", ty.i32());
AST().AddConstructedType(my_int);
u32 unsigned_value = 2u; // Type does not match variable type
- auto* var =
- Var("my_var", my_int, ast::StorageClass::kNone, Expr(unsigned_value));
-
- auto* decl = Decl(Source{{{3, 3}, {3, 22}}}, var);
+ auto* decl = Decl(Var(Source{{3, 3}}, "my_var", my_int,
+ ast::StorageClass::kNone, Expr(unsigned_value)));
WrapInFunction(decl);
EXPECT_FALSE(r()->Resolve());
diff --git a/src/writer/spirv/builder_function_variable_test.cc b/src/writer/spirv/builder_function_variable_test.cc
index 1f4217a..7fae8b3 100644
--- a/src/writer/spirv/builder_function_variable_test.cc
+++ b/src/writer/spirv/builder_function_variable_test.cc
@@ -45,7 +45,7 @@
TEST_F(BuilderTest, FunctionVar_WithConstantConstructor) {
auto* init = vec3<f32>(1.f, 1.f, 3.f);
- auto* v = Global("var", ty.f32(), ast::StorageClass::kOutput, init);
+ auto* v = Global("var", ty.vec3<f32>(), ast::StorageClass::kOutput, init);
spirv::Builder& b = Build();
@@ -60,8 +60,8 @@
%3 = OpConstant %2 1
%4 = OpConstant %2 3
%5 = OpConstantComposite %1 %3 %3 %4
-%7 = OpTypePointer Function %2
-%8 = OpConstantNull %2
+%7 = OpTypePointer Function %1
+%8 = OpConstantNull %1
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%6 = OpVariable %7 Function %8
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index 23e8b78..79b8d5c 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -57,7 +57,7 @@
TEST_F(BuilderTest, GlobalVar_WithConstructor) {
auto* init = vec3<f32>(1.f, 1.f, 3.f);
- auto* v = Global("var", ty.f32(), ast::StorageClass::kOutput, init);
+ auto* v = Global("var", ty.vec3<f32>(), ast::StorageClass::kOutput, init);
spirv::Builder& b = Build();
@@ -71,7 +71,7 @@
%3 = OpConstant %2 1
%4 = OpConstant %2 3
%5 = OpConstantComposite %1 %3 %3 %4
-%7 = OpTypePointer Output %2
+%7 = OpTypePointer Output %1
%6 = OpVariable %7 Output %5
)");
}
@@ -79,7 +79,7 @@
TEST_F(BuilderTest, GlobalVar_Const) {
auto* init = vec3<f32>(1.f, 1.f, 3.f);
- auto* v = GlobalConst("var", ty.f32(), init);
+ auto* v = GlobalConst("var", ty.vec3<f32>(), init);
spirv::Builder& b = Build();
@@ -99,7 +99,7 @@
TEST_F(BuilderTest, GlobalVar_Complex_Constructor) {
auto* init = vec3<f32>(ast::ExpressionList{Expr(1.f), Expr(2.f), Expr(3.f)});
- auto* v = GlobalConst("var", ty.f32(), init);
+ auto* v = GlobalConst("var", ty.vec3<f32>(), init);
spirv::Builder& b = Build();
@@ -118,7 +118,7 @@
TEST_F(BuilderTest, GlobalVar_Complex_ConstructorWithExtract) {
auto* init = vec3<f32>(vec2<f32>(1.f, 2.f), 3.f);
- auto* v = GlobalConst("var", ty.f32(), init);
+ auto* v = GlobalConst("var", ty.vec3<f32>(), init);
spirv::Builder& b = Build();
diff --git a/src/writer/spirv/builder_ident_expression_test.cc b/src/writer/spirv/builder_ident_expression_test.cc
index 55f2a12..821f9e5 100644
--- a/src/writer/spirv/builder_ident_expression_test.cc
+++ b/src/writer/spirv/builder_ident_expression_test.cc
@@ -25,7 +25,7 @@
TEST_F(BuilderTest, IdentifierExpression_GlobalConst) {
auto* init = vec3<f32>(1.f, 1.f, 3.f);
- auto* v = GlobalConst("var", ty.f32(), init);
+ auto* v = GlobalConst("var", ty.vec3<f32>(), init);
auto* expr = Expr("var");
WrapInFunction(expr);