resolver: Validate compound assignment statements

Reuse the logic for resolving binary operator result types that was
implemented for binary expressions. This validates that the LHS and
RHS are compatible for the target operator. We then try to match the
resolved result type against the LHS store type.

Bug: tint:1325

Change-Id: If80a883079bb71fa6c4eb5545654279fefffacb4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/74362
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 4d30266..731c19f 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -732,6 +732,7 @@
     resolver/builtin_validation_test.cc
     resolver/call_test.cc
     resolver/call_validation_test.cc
+    resolver/compound_assignment_validation_test.cc
     resolver/compound_statement_test.cc
     resolver/control_block_validation_test.cc
     resolver/attribute_validation_test.cc
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 1d40cb1..eb87d3e 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -32,6 +32,7 @@
 #include "src/tint/ast/call_expression.h"
 #include "src/tint/ast/call_statement.h"
 #include "src/tint/ast/case_statement.h"
+#include "src/tint/ast/compound_assignment_statement.h"
 #include "src/tint/ast/continue_statement.h"
 #include "src/tint/ast/depth_multisampled_texture.h"
 #include "src/tint/ast/depth_texture.h"
@@ -2265,6 +2266,40 @@
         Expr(std::forward<RhsExpressionInit>(rhs)));
   }
 
+  /// Creates a ast::CompoundAssignmentStatement with input lhs and rhs
+  /// expressions, and a binary operator.
+  /// @param source the source information
+  /// @param lhs the left hand side expression initializer
+  /// @param rhs the right hand side expression initializer
+  /// @param op the binary operator
+  /// @returns the compound assignment statement pointer
+  template <typename LhsExpressionInit, typename RhsExpressionInit>
+  const ast::CompoundAssignmentStatement* CompoundAssign(
+      const Source& source,
+      LhsExpressionInit&& lhs,
+      RhsExpressionInit&& rhs,
+      ast::BinaryOp op) {
+    return create<ast::CompoundAssignmentStatement>(
+        source, Expr(std::forward<LhsExpressionInit>(lhs)),
+        Expr(std::forward<RhsExpressionInit>(rhs)), op);
+  }
+
+  /// Creates a ast::CompoundAssignmentStatement with input lhs and rhs
+  /// expressions, and a binary operator.
+  /// @param lhs the left hand side expression initializer
+  /// @param rhs the right hand side expression initializer
+  /// @param op the binary operator
+  /// @returns the compound assignment statement pointer
+  template <typename LhsExpressionInit, typename RhsExpressionInit>
+  const ast::CompoundAssignmentStatement* CompoundAssign(
+      LhsExpressionInit&& lhs,
+      RhsExpressionInit&& rhs,
+      ast::BinaryOp op) {
+    return create<ast::CompoundAssignmentStatement>(
+        Expr(std::forward<LhsExpressionInit>(lhs)),
+        Expr(std::forward<RhsExpressionInit>(rhs)), op);
+  }
+
   /// Creates a ast::LoopStatement with input body and optional continuing
   /// @param source the source information
   /// @param body the loop body
diff --git a/src/tint/resolver/compound_assignment_validation_test.cc b/src/tint/resolver/compound_assignment_validation_test.cc
new file mode 100644
index 0000000..06ff43c
--- /dev/null
+++ b/src/tint/resolver/compound_assignment_validation_test.cc
@@ -0,0 +1,300 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/storage_texture_type.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverCompoundAssignmentValidationTest = ResolverTest;
+
+TEST_F(ResolverCompoundAssignmentValidationTest, CompatibleTypes) {
+  // var a : i32 = 2;
+  // a += 2
+  auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
+  WrapInFunction(var,
+                 CompoundAssign(Source{{12, 34}}, "a", 2, ast::BinaryOp::kAdd));
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, CompatibleTypesThroughAlias) {
+  // alias myint = i32;
+  // var a : myint = 2;
+  // a += 2
+  auto* myint = Alias("myint", ty.i32());
+  auto* var = Var("a", ty.Of(myint), ast::StorageClass::kNone, Expr(2));
+  WrapInFunction(var,
+                 CompoundAssign(Source{{12, 34}}, "a", 2, ast::BinaryOp::kAdd));
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest,
+       CompatibleTypesAssignThroughPointer) {
+  // var a : i32;
+  // let b : ptr<function,i32> = &a;
+  // *b += 2;
+  const auto func = ast::StorageClass::kFunction;
+  auto* var_a = Var("a", ty.i32(), func, Expr(2));
+  auto* var_b = Const("b", ty.pointer<int>(func), AddressOf(Expr("a")));
+  WrapInFunction(
+      var_a, var_b,
+      CompoundAssign(Source{{12, 34}}, Deref("b"), 2, ast::BinaryOp::kAdd));
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleTypes) {
+  // {
+  //   var a : i32 = 2;
+  //   a += 2.3;
+  // }
+
+  auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
+
+  auto* assign =
+      CompoundAssign(Source{{12, 34}}, "a", 2.3f, ast::BinaryOp::kAdd);
+  WrapInFunction(var, assign);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(r()->error(),
+            "12:34 error: compound assignment operand types are invalid: i32 "
+            "add f32");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleOp) {
+  // {
+  //   var a : f32 = 1.0;
+  //   a |= 2.0;
+  // }
+
+  auto* var = Var("a", ty.f32(), ast::StorageClass::kNone, Expr(1.f));
+
+  auto* assign =
+      CompoundAssign(Source{{12, 34}}, "a", 2.0f, ast::BinaryOp::kOr);
+  WrapInFunction(var, assign);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: compound assignment operand types are invalid: f32 or f32");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, VectorScalar_Pass) {
+  // {
+  //   var a : vec4<f32>;
+  //   a += 1.0;
+  // }
+
+  auto* var = Var("a", ty.vec4<f32>(), ast::StorageClass::kNone);
+
+  auto* assign =
+      CompoundAssign(Source{{12, 34}}, "a", 1.f, ast::BinaryOp::kAdd);
+  WrapInFunction(var, assign);
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, ScalarVector_Fail) {
+  // {
+  //   var a : f32;
+  //   a += vec4<f32>();
+  // }
+
+  auto* var = Var("a", ty.f32(), ast::StorageClass::kNone);
+
+  auto* assign =
+      CompoundAssign(Source{{12, 34}}, "a", vec4<f32>(), ast::BinaryOp::kAdd);
+  WrapInFunction(var, assign);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(r()->error(), "12:34 error: cannot assign 'vec4<f32>' to 'f32'");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, MatrixScalar_Pass) {
+  // {
+  //   var a : mat4x4<f32>;
+  //   a *= 2.0;
+  // }
+
+  auto* var = Var("a", ty.mat4x4<f32>(), ast::StorageClass::kNone);
+
+  auto* assign =
+      CompoundAssign(Source{{12, 34}}, "a", 2.f, ast::BinaryOp::kMultiply);
+  WrapInFunction(var, assign);
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, ScalarMatrix_Fail) {
+  // {
+  //   var a : f32;
+  //   a *= mat4x4();
+  // }
+
+  auto* var = Var("a", ty.f32(), ast::StorageClass::kNone);
+
+  auto* assign = CompoundAssign(Source{{12, 34}}, "a", mat4x4<f32>(),
+                                ast::BinaryOp::kMultiply);
+  WrapInFunction(var, assign);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(r()->error(), "12:34 error: cannot assign 'mat4x4<f32>' to 'f32'");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, VectorMatrix_Pass) {
+  // {
+  //   var a : vec4<f32>;
+  //   a *= mat4x4();
+  // }
+
+  auto* var = Var("a", ty.vec4<f32>(), ast::StorageClass::kNone);
+
+  auto* assign = CompoundAssign(Source{{12, 34}}, "a", mat4x4<f32>(),
+                                ast::BinaryOp::kMultiply);
+  WrapInFunction(var, assign);
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, VectorMatrix_ColumnMismatch) {
+  // {
+  //   var a : vec4<f32>;
+  //   a *= mat4x2();
+  // }
+
+  auto* var = Var("a", ty.vec4<f32>(), ast::StorageClass::kNone);
+
+  auto* assign = CompoundAssign(Source{{12, 34}}, "a", mat4x2<f32>(),
+                                ast::BinaryOp::kMultiply);
+  WrapInFunction(var, assign);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(r()->error(),
+            "12:34 error: compound assignment operand types are invalid: "
+            "vec4<f32> multiply mat4x2<f32>");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, VectorMatrix_ResultMismatch) {
+  // {
+  //   var a : vec4<f32>;
+  //   a *= mat2x4();
+  // }
+
+  auto* var = Var("a", ty.vec4<f32>(), ast::StorageClass::kNone);
+
+  auto* assign = CompoundAssign(Source{{12, 34}}, "a", mat2x4<f32>(),
+                                ast::BinaryOp::kMultiply);
+  WrapInFunction(var, assign);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(r()->error(),
+            "12:34 error: cannot assign 'vec2<f32>' to 'vec4<f32>'");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, MatrixVector_Fail) {
+  // {
+  //   var a : mat4x4<f32>;
+  //   a *= vec4();
+  // }
+
+  auto* var = Var("a", ty.mat4x4<f32>(), ast::StorageClass::kNone);
+
+  auto* assign = CompoundAssign(Source{{12, 34}}, "a", vec4<f32>(),
+                                ast::BinaryOp::kMultiply);
+  WrapInFunction(var, assign);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(r()->error(),
+            "12:34 error: cannot assign 'vec4<f32>' to 'mat4x4<f32>'");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, Phony) {
+  // {
+  //   _ += 1;
+  // }
+  WrapInFunction(
+      CompoundAssign(Source{{56, 78}}, Phony(), 1, ast::BinaryOp::kAdd));
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "56:78 error: compound assignment operand types are invalid: void "
+            "add i32");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, ReadOnlyBuffer) {
+  // @group(0) @binding(0) var<storage,read> a : i32;
+  // {
+  //   a += 1;
+  // }
+  Global(Source{{12, 34}}, "a", ty.i32(), ast::StorageClass::kStorage,
+         ast::Access::kRead, GroupAndBinding(0, 0));
+  WrapInFunction(CompoundAssign(Source{{56, 78}}, "a", 1, ast::BinaryOp::kAdd));
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "56:78 error: cannot store into a read-only type 'ref<storage, "
+            "i32, read>'");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, LhsConstant) {
+  // let a = 1;
+  // a += 1;
+  auto* a = Const(Source{{12, 34}}, "a", nullptr, Expr(1));
+  WrapInFunction(
+      a, CompoundAssign(Expr(Source{{56, 78}}, "a"), 1, ast::BinaryOp::kAdd));
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), R"(56:78 error: cannot assign to const
+12:34 note: 'a' is declared here:)");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, LhsLiteral) {
+  // 1 += 1;
+  WrapInFunction(
+      CompoundAssign(Expr(Source{{56, 78}}, 1), 1, ast::BinaryOp::kAdd));
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), "56:78 error: cannot assign to value of type 'i32'");
+}
+
+TEST_F(ResolverCompoundAssignmentValidationTest, LhsAtomic) {
+  // var<workgroup> a : atomic<i32>;
+  // a += a;
+  Global(Source{{12, 34}}, "a", ty.atomic(ty.i32()),
+         ast::StorageClass::kWorkgroup);
+  WrapInFunction(
+      CompoundAssign(Source{{56, 78}}, "a", "a", ast::BinaryOp::kAdd));
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "56:78 error: compound assignment operand types are invalid: "
+            "atomic<i32> add atomic<i32>");
+}
+
+}  // namespace
+}  // namespace resolver
+}  // namespace tint
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 8d6ef88..8a0db8b 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -222,6 +222,10 @@
         [&](const ast::CallStatement* r) {  //
           TraverseExpression(r->expr);
         },
+        [&](const ast::CompoundAssignmentStatement* a) {
+          TraverseExpression(a->lhs);
+          TraverseExpression(a->rhs);
+        },
         [&](const ast::ForLoopStatement* l) {
           scope_stack_.Push();
           TINT_DEFER(scope_stack_.Pop());
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index e10425a..f6797b1 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -860,6 +860,9 @@
       [&](const ast::AssignmentStatement* a) { return AssignmentStatement(a); },
       [&](const ast::BreakStatement* b) { return BreakStatement(b); },
       [&](const ast::CallStatement* c) { return CallStatement(c); },
+      [&](const ast::CompoundAssignmentStatement* c) {
+        return CompoundAssignmentStatement(c);
+      },
       [&](const ast::ContinueStatement* c) { return ContinueStatement(c); },
       [&](const ast::DiscardStatement* d) { return DiscardStatement(d); },
       [&](const ast::FallthroughStatement* f) {
@@ -2584,7 +2587,7 @@
       behaviors.Add(lhs->Behaviors());
     }
 
-    return ValidateAssignment(stmt);
+    return ValidateAssignment(stmt, TypeOf(stmt->rhs));
   });
 }
 
@@ -2610,6 +2613,37 @@
   });
 }
 
+sem::Statement* Resolver::CompoundAssignmentStatement(
+    const ast::CompoundAssignmentStatement* stmt) {
+  auto* sem = builder_->create<sem::Statement>(
+      stmt, current_compound_statement_, current_function_);
+  return StatementScope(stmt, sem, [&] {
+    auto* lhs = Expression(stmt->lhs);
+    if (!lhs) {
+      return false;
+    }
+
+    auto* rhs = Expression(stmt->rhs);
+    if (!rhs) {
+      return false;
+    }
+
+    sem->Behaviors() = rhs->Behaviors() + lhs->Behaviors();
+
+    auto* lhs_ty = lhs->Type()->UnwrapRef();
+    auto* rhs_ty = rhs->Type()->UnwrapRef();
+    auto* ty = BinaryOpType(lhs_ty, rhs_ty, stmt->op);
+    if (!ty) {
+      AddError("compound assignment operand types are invalid: " +
+                   TypeNameOf(lhs_ty) + " " + FriendlyName(stmt->op) + " " +
+                   TypeNameOf(rhs_ty),
+               stmt->source);
+      return false;
+    }
+    return ValidateAssignment(stmt, ty);
+  });
+}
+
 sem::Statement* Resolver::ContinueStatement(
     const ast::ContinueStatement* stmt) {
   auto* sem = builder_->create<sem::Statement>(
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 7c3d217..d5d1632 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -214,6 +214,8 @@
   sem::Statement* BreakStatement(const ast::BreakStatement*);
   sem::Statement* CallStatement(const ast::CallStatement*);
   sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
+  sem::Statement* CompoundAssignmentStatement(
+      const ast::CompoundAssignmentStatement*);
   sem::Statement* ContinueStatement(const ast::ContinueStatement*);
   sem::Statement* DiscardStatement(const ast::DiscardStatement*);
   sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
@@ -245,7 +247,7 @@
                                     const Source& source);
   bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
   bool ValidateAtomicVariable(const sem::Variable* var);
-  bool ValidateAssignment(const ast::AssignmentStatement* a);
+  bool ValidateAssignment(const ast::Statement* a, const sem::Type* rhs_ty);
   bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to);
   bool ValidateBreakStatement(const sem::Statement* stmt);
   bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/resolver_validation.cc
index 8dcd207..57cfb47 100644
--- a/src/tint/resolver/resolver_validation.cc
+++ b/src/tint/resolver/resolver_validation.cc
@@ -2264,10 +2264,22 @@
   return true;
 }
 
-bool Resolver::ValidateAssignment(const ast::AssignmentStatement* a) {
-  auto const* rhs_ty = TypeOf(a->rhs);
+bool Resolver::ValidateAssignment(const ast::Statement* a,
+                                  const sem::Type* rhs_ty) {
+  const ast::Expression* lhs;
+  const ast::Expression* rhs;
+  if (auto* assign = a->As<ast::AssignmentStatement>()) {
+    lhs = assign->lhs;
+    rhs = assign->rhs;
+  } else if (auto* compound = a->As<ast::CompoundAssignmentStatement>()) {
+    lhs = compound->lhs;
+    rhs = compound->rhs;
+  } else {
+    TINT_ICE(Resolver, diagnostics_) << "invalid assignment statement";
+    return false;
+  }
 
-  if (a->lhs->Is<ast::PhonyExpression>()) {
+  if (lhs->Is<ast::PhonyExpression>()) {
     // https://www.w3.org/TR/WGSL/#phony-assignment-section
     auto* ty = rhs_ty->UnwrapRef();
     if (!ty->IsConstructible() &&
@@ -2276,26 +2288,26 @@
           "cannot assign '" + TypeNameOf(rhs_ty) +
               "' to '_'. '_' can only be assigned a constructible, pointer, "
               "texture or sampler type",
-          a->rhs->source);
+          rhs->source);
       return false;
     }
     return true;  // RHS can be anything.
   }
 
   // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement
-  auto const* lhs_ty = TypeOf(a->lhs);
+  auto const* lhs_ty = TypeOf(lhs);
 
-  if (auto* var = ResolvedSymbol<sem::Variable>(a->lhs)) {
+  if (auto* var = ResolvedSymbol<sem::Variable>(lhs)) {
     auto* decl = var->Declaration();
     if (var->Is<sem::Parameter>()) {
-      AddError("cannot assign to function parameter", a->lhs->source);
+      AddError("cannot assign to function parameter", lhs->source);
       AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
                   "' is declared here:",
               decl->source);
       return false;
     }
     if (decl->is_const) {
-      AddError("cannot assign to const", a->lhs->source);
+      AddError("cannot assign to const", lhs->source);
       AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
                   "' is declared here:",
               decl->source);
@@ -2307,7 +2319,7 @@
   if (!lhs_ref) {
     // LHS is not a reference, so it has no storage.
     AddError("cannot assign to value of type '" + TypeNameOf(lhs_ty) + "'",
-             a->lhs->source);
+             lhs->source);
     return false;
   }
 
diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn
index b0edf0f..8d193da 100644
--- a/test/tint/BUILD.gn
+++ b/test/tint/BUILD.gn
@@ -245,6 +245,7 @@
     "../../src/tint/resolver/builtins_validation_test.cc",
     "../../src/tint/resolver/call_test.cc",
     "../../src/tint/resolver/call_validation_test.cc",
+    "../../src/tint/resolver/compound_assignment_validation_test.cc",
     "../../src/tint/resolver/compound_statement_test.cc",
     "../../src/tint/resolver/control_block_validation_test.cc",
     "../../src/tint/resolver/dependency_graph_test.cc",