[Tint] Validate compound assign's binary exprs

Add validation of compound assignments as binary expressions. This
resolves a later ICE that will occur for invalid compound assignments
when they are desugaredd to invalid binary expressions.

Fixes: 354070144
Change-Id: Ib2e2d355201d0ca781b903f498e1102892f96540
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/199775
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Natalie Chouinard <chouinard@google.com>
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 039f602..aeba08d 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -3594,7 +3594,7 @@
         }
     }
 
-    if (!validator_.BinaryExpression(expr, rhs, lhs_ty)) {
+    if (!validator_.BinaryExpression(expr, expr->op, lhs, rhs)) {
         return nullptr;
     }
 
diff --git a/src/tint/lang/wgsl/resolver/validation_test.cc b/src/tint/lang/wgsl/resolver/validation_test.cc
index 3d854e5..fa55dd0 100644
--- a/src/tint/lang/wgsl/resolver/validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/validation_test.cc
@@ -1453,6 +1453,25 @@
         R"(1:2 error: shift left value must be less than the bit width of the lhs, which is 32)");
 }
 
+TEST_F(ResolverValidationTest, ShiftLeft_I32_CompoundAssign_Valid) {
+    GlobalVar("v", ty.i32(), core::AddressSpace::kPrivate);
+    auto* expr = CompoundAssign(Source{{1, 2}}, "v", 1_u, core::BinaryOp::kShiftLeft);
+    WrapInFunction(expr);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, ShiftLeft_I32_CompoundAssign_Invalid) {
+    GlobalVar("v", ty.i32(), core::AddressSpace::kPrivate);
+    auto* expr = CompoundAssign(Source{{1, 2}}, "v", 64_u, core::BinaryOp::kShiftLeft);
+    WrapInFunction(expr);
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(
+        r()->error(),
+        R"(1:2 error: shift left value must be less than the bit width of the lhs, which is 32)");
+}
+
 }  // namespace
 }  // namespace tint::resolver
 
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index 014c784..f8b75f1 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -1672,23 +1672,24 @@
     return true;
 }
 
-bool Validator::BinaryExpression(const ast::BinaryExpression* expr,
-                                 const tint::sem::ValueExpression* rhs,
-                                 const tint::core::type::Type* lhs_ty) const {
-    switch (expr->op) {
+bool Validator::BinaryExpression(const ast::Node* node,
+                                 const core::BinaryOp op,
+                                 const tint::sem::ValueExpression* lhs,
+                                 const tint::sem::ValueExpression* rhs) const {
+    switch (op) {
         case core::BinaryOp::kShiftLeft:
-        case core::BinaryOp::kShiftRight:
+        case core::BinaryOp::kShiftRight: {
             // If lhs value is a concrete type, and rhs is a const-expression greater than or equal
             // to the bit width of lhs, then it is a shader-creation error.
-            if (!lhs_ty->HoldsAbstract() && rhs->Stage() == core::EvaluationStage::kConstant) {
-                const uint32_t bit_width = lhs_ty->DeepestElement()->Size() * 8;
+            const auto* elem_type = lhs->Type()->UnwrapRef()->DeepestElement();
+            if (!elem_type->HoldsAbstract() && rhs->Stage() == core::EvaluationStage::kConstant) {
+                const uint32_t bit_width = elem_type->Size() * 8;
                 auto* rhs_val = rhs->ConstantValue();
                 for (size_t i = 0, n = rhs_val->NumElements(); i < n; i++) {
                     auto* shift_val = n == 1 ? rhs_val : rhs_val->Index(i);
                     if (shift_val->ValueAs<u32>() >= bit_width) {
-                        AddError(expr->source)
-                            << "shift "
-                            << (expr->op == core::BinaryOp::kShiftLeft ? "left" : "right")
+                        AddError(node->source)
+                            << "shift " << (op == core::BinaryOp::kShiftLeft ? "left" : "right")
                             << " value must be less than the bit width of the lhs, which is "
                             << bit_width;
                         return false;
@@ -1696,8 +1697,10 @@
                 }
             }
             return true;
-        default:
+        }
+        default: {
             return true;
+        }
     }
 }
 
@@ -2750,6 +2753,9 @@
     } else if (auto* compound = a->As<ast::CompoundAssignmentStatement>()) {
         lhs = compound->lhs;
         rhs = compound->rhs;
+        if (!BinaryExpression(a, compound->op, sem_.GetVal(lhs), sem_.GetVal(rhs))) {
+            return false;
+        }
     } else {
         TINT_ICE() << "invalid assignment statement";
     }
diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h
index fb3a18e..463f593 100644
--- a/src/tint/lang/wgsl/resolver/validator.h
+++ b/src/tint/lang/wgsl/resolver/validator.h
@@ -224,13 +224,15 @@
     bool Assignment(const ast::Statement* a, const core::type::Type* rhs_ty) const;
 
     /// Validates a binary expression
-    /// @param expr the ast binary expression
+    /// @param node the ast binary expression or compound assignment node
+    /// @param op the binary operator
+    /// @param lhs the left hand side sem node
     /// @param rhs the right hand side sem node
-    /// @param lhs_ty the type of the left hand side
     /// @returns true on success, false otherwise.
-    bool BinaryExpression(const ast::BinaryExpression* expr,
-                          const tint::sem::ValueExpression* rhs,
-                          const tint::core::type::Type* lhs_ty) const;
+    bool BinaryExpression(const ast::Node* node,
+                          const core::BinaryOp op,
+                          const tint::sem::ValueExpression* lhs,
+                          const tint::sem::ValueExpression* rhs) const;
 
     /// Validates a break statement
     /// @param stmt the break statement to validate