validator: Support assignment through pointer

Also enable a test to check assigning to scalar literal.

Fixed: tint:419
Change-Id: Ic565af22c4ef6b60c41faaf9fabe3bd55fe48d2d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/37961
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc
index 3461e51..15dfbbd 100644
--- a/src/validator/validator_impl.cc
+++ b/src/validator/validator_impl.cc
@@ -30,6 +30,7 @@
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/matrix_type.h"
+#include "src/ast/type/pointer_type.h"
 #include "src/ast/type/struct_type.h"
 #include "src/ast/type/u32_type.h"
 #include "src/ast/type/vector_type.h"
@@ -277,6 +278,10 @@
               "redeclared identifier '" + module_.SymbolToName(symbol) + "'");
     return false;
   }
+  // TODO(dneto): Check type compatibility of the initializer.
+  //  - if it's non-constant, then is storable or can be dereferenced to be
+  //    storable.
+  //  - types match or the RHS can be dereferenced to equal the LHS type.
   variable_stack_.set(symbol, decl->variable());
   if (auto* arr =
           decl->variable()->type()->UnwrapAll()->As<ast::type::Array>()) {
@@ -429,57 +434,78 @@
   return true;
 }
 
-bool ValidatorImpl::ValidateAssign(const ast::AssignmentStatement* a) {
-  if (!a) {
+bool ValidatorImpl::ValidateBadAssignmentToIdentifier(
+    const ast::AssignmentStatement* assign) {
+  auto* ident = assign->lhs()->As<ast::IdentifierExpression>();
+  if (!ident) {
+    // It wasn't an identifier in the first place.
+    return true;
+  }
+  ast::Variable* var;
+  if (variable_stack_.get(ident->symbol(), &var)) {
+    // Give a nicer message if the LHS of the assignment is a const identifier.
+    // It's likely to be a common programmer error.
+    if (var->is_const()) {
+      add_error(assign->source(), "v-0021",
+                "cannot re-assign a constant: '" +
+                    module_.SymbolToName(ident->symbol()) + "'");
+      return false;
+    }
+  } else {
+    // The identifier is not defined. This should already have been caught
+    // when validating the subexpression.
+    add_error(
+        ident->source(), "v-0006",
+        "'" + module_.SymbolToName(ident->symbol()) + "' is not declared");
     return false;
   }
-  if (!(ValidateConstant(a))) {
-    return false;
-  }
-  if (!(ValidateExpression(a->lhs()) && ValidateExpression(a->rhs()))) {
-    return false;
-  }
-  if (!ValidateResultTypes(a)) {
-    return false;
-  }
-
   return true;
 }
 
-bool ValidatorImpl::ValidateConstant(const ast::AssignmentStatement* assign) {
+bool ValidatorImpl::ValidateAssign(const ast::AssignmentStatement* assign) {
   if (!assign) {
     return false;
   }
-
-  if (auto* ident = assign->lhs()->As<ast::IdentifierExpression>()) {
-    ast::Variable* var;
-    if (variable_stack_.get(ident->symbol(), &var)) {
-      if (var->is_const()) {
-        add_error(assign->source(), "v-0021",
-                  "cannot re-assign a constant: '" +
-                      module_.SymbolToName(ident->symbol()) + "'");
-        return false;
-      }
+  auto* lhs = assign->lhs();
+  auto* rhs = assign->rhs();
+  if (!ValidateExpression(lhs)) {
+    return false;
+  }
+  if (!ValidateExpression(rhs)) {
+    return false;
+  }
+  // Pointers are not storable in WGSL, but the right-hand side must be
+  // storable. The raw right-hand side might be a pointer value which must be
+  // loaded (dereferenced) to provide the value to be stored.
+  auto* rhs_result_type = rhs->result_type()->UnwrapAll();
+  if (!IsStorable(rhs_result_type)) {
+    add_error(assign->source(), "v-000x",
+              "invalid assignment: right-hand-side is not storable: " +
+                  rhs->result_type()->type_name());
+    return false;
+  }
+  auto* lhs_result_type = lhs->result_type()->UnwrapIfNeeded();
+  if (auto* lhs_reference_type = As<ast::type::Pointer>(lhs_result_type)) {
+    auto* lhs_store_type = lhs_reference_type->type()->UnwrapIfNeeded();
+    if (lhs_store_type != rhs_result_type) {
+      add_error(assign->source(), "v-000x",
+                "invalid assignment: can't assign value of type '" +
+                    rhs_result_type->type_name() + "' to '" +
+                    lhs_store_type->type_name() + "'");
+      return false;
     }
-  }
-  return true;
-}
-
-bool ValidatorImpl::ValidateResultTypes(const ast::AssignmentStatement* a) {
-  if (!a->lhs()->result_type() || !a->rhs()->result_type()) {
-    add_error(a->source(), "result_type() is nullptr");
+  } else {
+    if (!ValidateBadAssignmentToIdentifier(assign)) {
+      return false;
+    }
+    // Issue a generic error.
+    add_error(
+        assign->source(), "v-000x",
+        "invalid assignment: left-hand-side does not reference storage: " +
+            lhs->result_type()->type_name());
     return false;
   }
 
-  auto* lhs_result_type = a->lhs()->result_type()->UnwrapAll();
-  auto* rhs_result_type = a->rhs()->result_type()->UnwrapAll();
-  if (lhs_result_type != rhs_result_type) {
-    // TODO(sarahM0): figur out what should be the error number.
-    add_error(a->source(), "v-000x",
-              "invalid assignment of '" + lhs_result_type->type_name() +
-                  "' to '" + rhs_result_type->type_name() + "'");
-    return false;
-  }
   return true;
 }
 
diff --git a/src/validator/validator_impl.h b/src/validator/validator_impl.h
index d619b82..f0a1b5d 100644
--- a/src/validator/validator_impl.h
+++ b/src/validator/validator_impl.h
@@ -100,6 +100,12 @@
   /// @param assign the assignment to check
   /// @returns true if the validation was successful
   bool ValidateAssign(const ast::AssignmentStatement* assign);
+  /// Validates a bad assignment to an identifier. Issues an error
+  /// and returns false if the left hand side is an identifier.
+  /// @param assign the assignment to check
+  /// @returns true if the LHS of theassignment is not an identifier expression
+  bool ValidateBadAssignmentToIdentifier(
+      const ast::AssignmentStatement* assign);
   /// Validates an expression
   /// @param expr the expression to check
   /// @return true if the expression is valid
@@ -108,14 +114,6 @@
   /// @param ident the identifer to check if its in the scope
   /// @return true if idnet was defined
   bool ValidateIdentifier(const ast::IdentifierExpression* ident);
-  /// Validates if the input follows type checking rules
-  /// @param assign the assignment to check
-  /// @returns ture if successful
-  bool ValidateResultTypes(const ast::AssignmentStatement* assign);
-  /// Validate v-0021: Cannot re-assign a constant
-  /// @param assign is the assigment to check if its lhs is a const
-  /// @returns false if lhs of assign is a constant identifier
-  bool ValidateConstant(const ast::AssignmentStatement* assign);
   /// Validates declaration name uniqueness
   /// @param decl is the new declaration to be added
   /// @returns true if no previous declaration with the `decl` 's name
@@ -154,6 +152,11 @@
   /// @returns true if the given type is storable.
   bool IsStorable(ast::type::Type* type);
 
+  /// Testing method to inserting a given variable into the current scope.
+  void RegisterVariableForTesting(ast::Variable* var) {
+    variable_stack_.set(var->symbol(), var);
+  }
+
  private:
   const ast::Module& module_;
   diag::List diags_;
diff --git a/src/validator/validator_test.cc b/src/validator/validator_test.cc
index 1aa10f5..de38988 100644
--- a/src/validator/validator_test.cc
+++ b/src/validator/validator_test.cc
@@ -60,18 +60,27 @@
 
 class ValidatorTest : public ValidatorTestHelper, public testing::Test {};
 
-TEST_F(ValidatorTest, DISABLED_AssignToScalar_Fail) {
+TEST_F(ValidatorTest, AssignToScalar_Fail) {
+  // var my_var : i32 = 2;
   // 1 = my_var;
 
+  auto* var = Var("my_var", ast::StorageClass::kNone, ty.i32, Expr(2),
+                  ast::VariableDecorationList{});
+
   auto* lhs = Expr(1);
   auto* rhs = Expr("my_var");
   SetSource(Source{Source::Location{12, 34}});
-  create<ast::AssignmentStatement>(lhs, rhs);
+  auto* assign = create<ast::AssignmentStatement>(lhs, rhs);
+  RegisterVariable(var);
 
+  EXPECT_TRUE(td()->DetermineResultType(assign));
   // TODO(sarahM0): Invalidate assignment to scalar.
+  EXPECT_FALSE(v()->ValidateAssign(assign));
   ASSERT_TRUE(v()->has_error());
   // TODO(sarahM0): figure out what should be the error number.
-  EXPECT_EQ(v()->error(), "12:34 v-000x: invalid assignment");
+  EXPECT_EQ(v()->error(),
+            "12:34 v-000x: invalid assignment: left-hand-side does not "
+            "reference storage: __i32");
 }
 
 TEST_F(ValidatorTest, UsingUndefinedVariable_Fail) {
@@ -116,11 +125,75 @@
 
   auto* assign = create<ast::AssignmentStatement>(
       Source{Source::Location{12, 34}}, lhs, rhs);
-  td()->RegisterVariableForTesting(var);
+  RegisterVariable(var);
   EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
   ASSERT_NE(lhs->result_type(), nullptr);
   ASSERT_NE(rhs->result_type(), nullptr);
-  EXPECT_TRUE(v()->ValidateResultTypes(assign));
+  EXPECT_TRUE(v()->ValidateAssign(assign)) << v()->error();
+}
+
+TEST_F(ValidatorTest, AssignCompatibleTypesThroughAlias_Pass) {
+  // alias myint = i32;
+  // var a :myint = 2;
+  // a = 2
+  auto* myint = ty.alias("myint", ty.i32);
+  auto* var = Var("a", ast::StorageClass::kNone, myint, Expr(2),
+                  ast::VariableDecorationList{});
+
+  auto* lhs = Expr("a");
+  auto* rhs = Expr(2);
+
+  auto* assign = create<ast::AssignmentStatement>(
+      Source{Source::Location{12, 34}}, lhs, rhs);
+  RegisterVariable(var);
+  EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
+  ASSERT_NE(lhs->result_type(), nullptr);
+  ASSERT_NE(rhs->result_type(), nullptr);
+  EXPECT_TRUE(v()->ValidateAssign(assign)) << v()->error();
+}
+
+TEST_F(ValidatorTest, AssignCompatibleTypesInferRHSLoad_Pass) {
+  // var a :i32 = 2;
+  // var b :i32 = 3;
+  // a = b;
+  auto* var_a = Var("a", ast::StorageClass::kNone, ty.i32, Expr(2),
+                    ast::VariableDecorationList{});
+  auto* var_b = Var("b", ast::StorageClass::kNone, ty.i32, Expr(3),
+                    ast::VariableDecorationList{});
+
+  auto* lhs = Expr("a");
+  auto* rhs = Expr("b");
+
+  auto* assign = create<ast::AssignmentStatement>(
+      Source{Source::Location{12, 34}}, lhs, rhs);
+  RegisterVariable(var_a);
+  RegisterVariable(var_b);
+  EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
+  ASSERT_NE(lhs->result_type(), nullptr);
+  ASSERT_NE(rhs->result_type(), nullptr);
+  EXPECT_TRUE(v()->ValidateAssign(assign)) << v()->error();
+}
+
+TEST_F(ValidatorTest, AssignThroughPointer_Pass) {
+  // var a :i32;
+  // const b : ptr<function,i32> = a;
+  // b = 2;
+  const auto func = ast::StorageClass::kFunction;
+  auto* var_a = Var("a", func, ty.i32, Expr(2), {});
+  auto* var_b = Const("b", ast::StorageClass::kNone, ty.pointer<int>(func),
+                      Expr("a"), {});
+
+  auto* lhs = Expr("b");
+  auto* rhs = Expr(2);
+
+  auto* assign = create<ast::AssignmentStatement>(
+      Source{Source::Location{12, 34}}, lhs, rhs);
+  RegisterVariable(var_a);
+  RegisterVariable(var_b);
+  EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
+  ASSERT_NE(lhs->result_type(), nullptr);
+  ASSERT_NE(rhs->result_type(), nullptr);
+  EXPECT_TRUE(v()->ValidateAssign(assign)) << v()->error();
 }
 
 TEST_F(ValidatorTest, AssignIncompatibleTypes_Fail) {
@@ -137,16 +210,42 @@
 
   auto* assign = create<ast::AssignmentStatement>(
       Source{Source::Location{12, 34}}, lhs, rhs);
-  td()->RegisterVariableForTesting(var);
+  RegisterVariable(var);
   EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
   ASSERT_NE(lhs->result_type(), nullptr);
   ASSERT_NE(rhs->result_type(), nullptr);
 
-  EXPECT_FALSE(v()->ValidateResultTypes(assign));
+  EXPECT_FALSE(v()->ValidateAssign(assign));
   ASSERT_TRUE(v()->has_error());
   // TODO(sarahM0): figure out what should be the error number.
   EXPECT_EQ(v()->error(),
-            "12:34 v-000x: invalid assignment of '__i32' to '__f32'");
+            "12:34 v-000x: invalid assignment: can't assign value of type "
+            "'__f32' to '__i32'");
+}
+
+TEST_F(ValidatorTest, AssignThroughPointerWrongeStoreType_Fail) {
+  // var a :f32;
+  // const b : ptr<function,f32> = a;
+  // b = 2;
+  const auto priv = ast::StorageClass::kFunction;
+  auto* var_a = Var("a", priv, ty.f32, Expr(2), {});
+  auto* var_b = Const("b", ast::StorageClass::kNone, ty.pointer<float>(priv),
+                      Expr("a"), {});
+
+  auto* lhs = Expr("a");
+  auto* rhs = Expr(2);
+
+  auto* assign = create<ast::AssignmentStatement>(
+      Source{Source::Location{12, 34}}, lhs, rhs);
+  RegisterVariable(var_a);
+  RegisterVariable(var_b);
+  EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
+  ASSERT_NE(lhs->result_type(), nullptr);
+  ASSERT_NE(rhs->result_type(), nullptr);
+  EXPECT_FALSE(v()->ValidateAssign(assign));
+  EXPECT_EQ(v()->error(),
+            "12:34 v-000x: invalid assignment: can't assign value of type "
+            "'__i32' to '__f32'");
 }
 
 TEST_F(ValidatorTest, AssignCompatibleTypesInBlockStatement_Pass) {
@@ -199,7 +298,8 @@
   ASSERT_TRUE(v()->has_error());
   // TODO(sarahM0): figure out what should be the error number.
   EXPECT_EQ(v()->error(),
-            "12:34 v-000x: invalid assignment of '__i32' to '__f32'");
+            "12:34 v-000x: invalid assignment: can't assign value of type "
+            "'__f32' to '__i32'");
 }
 
 TEST_F(ValidatorTest, GlobalVariableWithStorageClass_Pass) {
diff --git a/src/validator/validator_test_helper.h b/src/validator/validator_test_helper.h
index bc15b32..a7f4374 100644
--- a/src/validator/validator_test_helper.h
+++ b/src/validator/validator_test_helper.h
@@ -39,6 +39,13 @@
   /// @returns a pointer to the type_determiner object
   TypeDeterminer* td() const { return td_.get(); }
 
+  /// Inserts a variable into the current scope.
+  /// @param var the variable to register.
+  void RegisterVariable(ast::Variable* var) {
+    v_->RegisterVariableForTesting(var);
+    td_->RegisterVariableForTesting(var);
+  }
+
  private:
   std::unique_ptr<ValidatorImpl> v_;
   std::unique_ptr<TypeDeterminer> td_;