Resolver: Move common logic into Variable()

Variable() is called for globals, locals and parameters. Much of the logic is the same.

Move all the common logic down into Variable(). This:
* Removes some yucky default parameters
* Adds type validation that was missing for globals (broken tests fixed)
* Gives me a single place to implement the Reference type wrapping

Bug: tint:727
Change-Id: I70f4a3603d7fa781da938508aa2a1bc80ec15d77
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50580
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 4c6eb94..8095bea 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -413,20 +413,57 @@
   return s;
 }
 
-Resolver::VariableInfo* Resolver::Variable(
-    ast::Variable* var,
-    const sem::Type* type, /* = nullptr */
-    std::string type_name /* = "" */) {
-  auto it = variable_to_info_.find(var);
-  if (it != variable_to_info_.end()) {
-    return it->second;
+Resolver::VariableInfo* Resolver::Variable(ast::Variable* var,
+                                           bool is_parameter) {
+  if (variable_to_info_.count(var)) {
+    TINT_ICE(diagnostics_) << "Variable "
+                           << builder_->Symbols().NameFor(var->symbol())
+                           << " already resolved";
+    return nullptr;
   }
 
-  if (type == nullptr && var->type()) {
-    type = Type(var->type());
-    type_name = var->type()->FriendlyName(builder_->Symbols());
+  // If the variable has a declared type, resolve it.
+  std::string type_name;
+  const sem::Type* type = nullptr;
+  if (auto* ty = var->type()) {
+    type_name = ty->FriendlyName(builder_->Symbols());
+    type = Type(ty);
+    if (!type) {
+      return nullptr;
+    }
   }
-  if (type == nullptr) {
+
+  // Does the variable have a constructor?
+  if (auto* ctor = var->constructor()) {
+    Mark(var->constructor());
+    if (!Expression(var->constructor())) {
+      return nullptr;
+    }
+
+    // Fetch the constructor's type
+    auto* rhs_type = TypeOf(ctor);
+    if (!rhs_type) {
+      return nullptr;
+    }
+
+    // If the variable has no declared type, infer it from the RHS
+    if (type == nullptr) {
+      type_name = TypeNameOf(ctor);
+      type = rhs_type->UnwrapPtr();
+    }
+
+    if (!IsValidAssignment(type, rhs_type)) {
+      diagnostics_.add_error(
+          "variable of type '" + type_name +
+              "' cannot be initialized with a value of type '" +
+              TypeNameOf(ctor) + "'",
+          var->source());
+      return nullptr;
+    }
+  } else if (var->is_const() && !is_parameter &&
+             !ast::HasDecoration<ast::OverrideDecoration>(var->decorations())) {
+    diagnostics_.add_error("let declarations must have initializers",
+                           var->source());
     return nullptr;
   }
 
@@ -446,7 +483,7 @@
     return false;
   }
 
-  auto* info = Variable(var);
+  auto* info = Variable(var, /* is_parameter */ false);
   if (!info) {
     return false;
   }
@@ -472,20 +509,6 @@
     info->binding_point = {bp.group->value(), bp.binding->value()};
   }
 
-  if (var->has_constructor()) {
-    Mark(var->constructor());
-    if (!Expression(var->constructor())) {
-      return false;
-    }
-  } else {
-    if (var->is_const() &&
-        !ast::HasDecoration<ast::OverrideDecoration>(var->decorations())) {
-      diagnostics_.add_error("let declarations must have initializers",
-                             var->source());
-      return false;
-    }
-  }
-
   if (!ValidateGlobalVariable(info)) {
     return false;
   }
@@ -1020,7 +1043,7 @@
   variable_stack_.push_scope();
   for (auto* param : func->params()) {
     Mark(param);
-    auto* param_info = Variable(param);
+    auto* param_info = Variable(param, /* is_parameter */ true);
     if (!param_info) {
       return false;
     }
@@ -2072,17 +2095,6 @@
   ast::Variable* var = stmt->variable();
   Mark(var);
 
-  // If the variable has a declared type, resolve it.
-  std::string type_name;
-  const sem::Type* type = nullptr;
-  if (auto* ast_ty = var->type()) {
-    type_name = ast_ty->FriendlyName(builder_->Symbols());
-    type = Type(ast_ty);
-    if (!type) {
-      return false;
-    }
-  }
-
   bool is_global = false;
   if (variable_stack_.get(var->symbol(), nullptr, &is_global)) {
     const char* error_code = is_global ? "v-0013" : "v-0014";
@@ -2093,33 +2105,9 @@
     return false;
   }
 
-  if (auto* ctor = stmt->variable()->constructor()) {
-    Mark(ctor);
-    if (!Expression(ctor)) {
-      return false;
-    }
-    auto* rhs_type = TypeOf(ctor);
-
-    // If the variable has no type, infer it from the rhs
-    if (type == nullptr) {
-      type_name = TypeNameOf(ctor);
-      type = rhs_type->UnwrapPtr();
-    }
-
-    if (!IsValidAssignment(type, rhs_type)) {
-      diagnostics_.add_error(
-          "variable of type '" + type_name +
-              "' cannot be initialized with a value of type '" +
-              TypeNameOf(ctor) + "'",
-          stmt->source());
-      return false;
-    }
-  } else {
-    if (stmt->variable()->is_const()) {
-      diagnostics_.add_error("let declarations must have initializers",
-                             var->source());
-      return false;
-    }
+  auto* info = Variable(var, /* is_parameter */ false);
+  if (!info) {
+    return false;
   }
 
   for (auto* deco : var->decorations()) {
@@ -2127,13 +2115,6 @@
     Mark(deco);
   }
 
-  auto* info = Variable(var, type, type_name);
-  if (!info) {
-    return false;
-  }
-  // TODO(bclayton): Remove this and fix tests. We're overriding the semantic
-  // type stored in info->type here with a possibly non-canonicalized type.
-  info->type = const_cast<sem::Type*>(type);
   variable_stack_.set(var->symbol(), info);
   current_block_->decls.push_back(var);
 
@@ -2212,8 +2193,7 @@
     TINT_ASSERT(type.sem);
   }
   if (expr_info_.count(expr)) {
-    TINT_ICE(builder_->Diagnostics())
-        << "SetType() called twice for the same expression";
+    TINT_ICE(diagnostics_) << "SetType() called twice for the same expression";
   }
   expr_info_.emplace(expr, ExpressionInfo{type, type_name, current_statement_});
 }
@@ -2260,9 +2240,8 @@
       } else {
         auto* sem_user = sem_expr->As<sem::VariableUser>();
         if (!sem_user) {
-          TINT_ICE(builder_->Diagnostics())
-              << "expected sem::VariableUser, got "
-              << sem_expr->TypeInfo().name;
+          TINT_ICE(diagnostics_) << "expected sem::VariableUser, got "
+                                 << sem_expr->TypeInfo().name;
         }
         sem_var->AddUser(sem_user);
       }
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 02dce4f..95a7365 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -262,13 +262,11 @@
 
   /// @returns the VariableInfo for the variable `var`, building it if it hasn't
   /// been constructed already. If an error is raised, nullptr is returned.
+  /// @note this method does not resolve the decorations as these are
+  /// context-dependent (global, local, parameter)
   /// @param var the variable to create or return the `VariableInfo` for
-  /// @param type optional type of `var` to use instead of `var->type()`.
-  /// @param type_name optional type name of `var` to use instead of
-  /// `var->type()->FriendlyName()`.
-  VariableInfo* Variable(ast::Variable* var,
-                         const sem::Type* type = nullptr,
-                         std::string type_name = "");
+  /// @param is_parameter true if the variable represents a parameter
+  VariableInfo* Variable(ast::Variable* var, bool is_parameter);
 
   /// Records the storage class usage for the given type, and any transient
   /// dependencies of the type. Validates that the type can be used for the
diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc
index 82a231e..a97d677 100644
--- a/src/resolver/type_validation_test.cc
+++ b/src/resolver/type_validation_test.cc
@@ -90,8 +90,8 @@
   // const<in> global_var: f32;
   AST().AddGlobalVariable(
       create<ast::Variable>(Source{{12, 34}}, Symbols().Register("global_var"),
-                            ast::StorageClass::kInput, ty.f32(), true, nullptr,
-                            ast::DecorationList{}));
+                            ast::StorageClass::kInput, ty.f32(), true,
+                            Expr(1.23f), ast::DecorationList{}));
 
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(r()->error(),
@@ -113,7 +113,7 @@
   Global("global_var0", ty.f32(), ast::StorageClass::kPrivate, Expr(0.1f));
 
   Global(Source{{12, 34}}, "global_var1", ty.f32(), ast::StorageClass::kPrivate,
-         Expr(0));
+         Expr(1.0f));
 
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 53c087c..87a9608 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -164,10 +164,8 @@
 TEST_F(ResolverValidationTest,
        Stmt_VariableDecl_MismatchedTypeScalarConstructor) {
   u32 unsigned_value = 2u;  // Type does not match variable type
-  auto* var =
-      Var("my_var", ty.i32(), ast::StorageClass::kNone, Expr(unsigned_value));
-
-  auto* decl = Decl(Source{{{3, 3}, {3, 22}}}, var);
+  auto* decl = Decl(Var(Source{{3, 3}}, "my_var", ty.i32(),
+                        ast::StorageClass::kNone, Expr(unsigned_value)));
   WrapInFunction(decl);
 
   EXPECT_FALSE(r()->Resolve());
@@ -181,10 +179,8 @@
   auto* my_int = ty.alias("MyInt", ty.i32());
   AST().AddConstructedType(my_int);
   u32 unsigned_value = 2u;  // Type does not match variable type
-  auto* var =
-      Var("my_var", my_int, ast::StorageClass::kNone, Expr(unsigned_value));
-
-  auto* decl = Decl(Source{{{3, 3}, {3, 22}}}, var);
+  auto* decl = Decl(Var(Source{{3, 3}}, "my_var", my_int,
+                        ast::StorageClass::kNone, Expr(unsigned_value)));
   WrapInFunction(decl);
 
   EXPECT_FALSE(r()->Resolve());
diff --git a/src/writer/spirv/builder_function_variable_test.cc b/src/writer/spirv/builder_function_variable_test.cc
index 1f4217a..7fae8b3 100644
--- a/src/writer/spirv/builder_function_variable_test.cc
+++ b/src/writer/spirv/builder_function_variable_test.cc
@@ -45,7 +45,7 @@
 TEST_F(BuilderTest, FunctionVar_WithConstantConstructor) {
   auto* init = vec3<f32>(1.f, 1.f, 3.f);
 
-  auto* v = Global("var", ty.f32(), ast::StorageClass::kOutput, init);
+  auto* v = Global("var", ty.vec3<f32>(), ast::StorageClass::kOutput, init);
 
   spirv::Builder& b = Build();
 
@@ -60,8 +60,8 @@
 %3 = OpConstant %2 1
 %4 = OpConstant %2 3
 %5 = OpConstantComposite %1 %3 %3 %4
-%7 = OpTypePointer Function %2
-%8 = OpConstantNull %2
+%7 = OpTypePointer Function %1
+%8 = OpConstantNull %1
 )");
   EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
             R"(%6 = OpVariable %7 Function %8
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index 23e8b78..79b8d5c 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -57,7 +57,7 @@
 TEST_F(BuilderTest, GlobalVar_WithConstructor) {
   auto* init = vec3<f32>(1.f, 1.f, 3.f);
 
-  auto* v = Global("var", ty.f32(), ast::StorageClass::kOutput, init);
+  auto* v = Global("var", ty.vec3<f32>(), ast::StorageClass::kOutput, init);
 
   spirv::Builder& b = Build();
 
@@ -71,7 +71,7 @@
 %3 = OpConstant %2 1
 %4 = OpConstant %2 3
 %5 = OpConstantComposite %1 %3 %3 %4
-%7 = OpTypePointer Output %2
+%7 = OpTypePointer Output %1
 %6 = OpVariable %7 Output %5
 )");
 }
@@ -79,7 +79,7 @@
 TEST_F(BuilderTest, GlobalVar_Const) {
   auto* init = vec3<f32>(1.f, 1.f, 3.f);
 
-  auto* v = GlobalConst("var", ty.f32(), init);
+  auto* v = GlobalConst("var", ty.vec3<f32>(), init);
 
   spirv::Builder& b = Build();
 
@@ -99,7 +99,7 @@
 TEST_F(BuilderTest, GlobalVar_Complex_Constructor) {
   auto* init = vec3<f32>(ast::ExpressionList{Expr(1.f), Expr(2.f), Expr(3.f)});
 
-  auto* v = GlobalConst("var", ty.f32(), init);
+  auto* v = GlobalConst("var", ty.vec3<f32>(), init);
 
   spirv::Builder& b = Build();
 
@@ -118,7 +118,7 @@
 TEST_F(BuilderTest, GlobalVar_Complex_ConstructorWithExtract) {
   auto* init = vec3<f32>(vec2<f32>(1.f, 2.f), 3.f);
 
-  auto* v = GlobalConst("var", ty.f32(), init);
+  auto* v = GlobalConst("var", ty.vec3<f32>(), init);
 
   spirv::Builder& b = Build();
 
diff --git a/src/writer/spirv/builder_ident_expression_test.cc b/src/writer/spirv/builder_ident_expression_test.cc
index 55f2a12..821f9e5 100644
--- a/src/writer/spirv/builder_ident_expression_test.cc
+++ b/src/writer/spirv/builder_ident_expression_test.cc
@@ -25,7 +25,7 @@
 TEST_F(BuilderTest, IdentifierExpression_GlobalConst) {
   auto* init = vec3<f32>(1.f, 1.f, 3.f);
 
-  auto* v = GlobalConst("var", ty.f32(), init);
+  auto* v = GlobalConst("var", ty.vec3<f32>(), init);
 
   auto* expr = Expr("var");
   WrapInFunction(expr);