[tint][ir][FromProgram] Refactor inc/dec/compound assignment

Consolidate the common logic.

Change-Id: I29f486095ad1492d5aeb57ad1fc6a6ba316d627d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/139923
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index c1bd2fb..2835d9d 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -491,99 +491,39 @@
     }
 
     void EmitIncrementDecrement(const ast::IncrementDecrementStatement* stmt) {
-        auto lhs = EmitExpression(stmt->lhs);
-        if (!lhs) {
-            return;
-        }
+        auto* one = program_->TypeOf(stmt->lhs)->UnwrapRef()->is_signed_integer_scalar()
+                        ? builder_.Constant(1_i)
+                        : builder_.Constant(1_u);
+        auto emit_rhs = [one] { return one; };
 
-        // Load from the LHS.
-        auto* lhs_value = builder_.Load(lhs.Get());
-        current_block_->Append(lhs_value);
-
-        auto* ty = lhs_value->Result()->Type();
-
-        auto* rhs =
-            ty->is_signed_integer_scalar() ? builder_.Constant(1_i) : builder_.Constant(1_u);
-
-        Binary* inst = nullptr;
-        if (stmt->increment) {
-            inst = builder_.Add(ty, lhs_value, rhs);
-        } else {
-            inst = builder_.Subtract(ty, lhs_value, rhs);
-        }
-        current_block_->Append(inst);
-
-        auto store = builder_.Store(lhs.Get(), inst);
-        current_block_->Append(store);
+        EmitCompoundAssignment(stmt->lhs, emit_rhs,
+                               stmt->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract);
     }
 
     void EmitCompoundAssignment(const ast::CompoundAssignmentStatement* stmt) {
-        auto lhs = EmitExpression(stmt->lhs);
+        auto emit_rhs = [this, stmt] {
+            auto rhs = EmitExpression(stmt->rhs);
+            return rhs ? rhs.Get() : nullptr;
+        };
+        EmitCompoundAssignment(stmt->lhs, emit_rhs, stmt->op);
+    }
+
+    template <typename EMIT_RHS>
+    void EmitCompoundAssignment(const ast::Expression* lhs_expr,
+                                EMIT_RHS&& emit_rhs,
+                                ast::BinaryOp op) {
+        auto lhs = EmitExpression(lhs_expr);
         if (!lhs) {
             return;
         }
-
-        auto rhs = EmitExpression(stmt->rhs);
+        auto rhs = emit_rhs();
         if (!rhs) {
             return;
         }
-
-        // Load from the LHS.
-        auto* lhs_value = builder_.Load(lhs.Get());
-        current_block_->Append(lhs_value);
-
-        auto* ty = lhs_value->Result()->Type();
-
-        Binary* inst = nullptr;
-        switch (stmt->op) {
-            case ast::BinaryOp::kAnd:
-                inst = builder_.And(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kOr:
-                inst = builder_.Or(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kXor:
-                inst = builder_.Xor(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kShiftLeft:
-                inst = builder_.ShiftLeft(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kShiftRight:
-                inst = builder_.ShiftRight(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kAdd:
-                inst = builder_.Add(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kSubtract:
-                inst = builder_.Subtract(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kMultiply:
-                inst = builder_.Multiply(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kDivide:
-                inst = builder_.Divide(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kModulo:
-                inst = builder_.Modulo(ty, lhs_value, rhs.Get());
-                break;
-            case ast::BinaryOp::kLessThanEqual:
-            case ast::BinaryOp::kGreaterThanEqual:
-            case ast::BinaryOp::kGreaterThan:
-            case ast::BinaryOp::kLessThan:
-            case ast::BinaryOp::kNotEqual:
-            case ast::BinaryOp::kEqual:
-            case ast::BinaryOp::kLogicalAnd:
-            case ast::BinaryOp::kLogicalOr:
-                TINT_ICE(IR, diagnostics_) << "invalid compound assignment";
-                return;
-            case ast::BinaryOp::kNone:
-                TINT_ICE(IR, diagnostics_) << "missing binary operand type";
-                return;
-        }
-        current_block_->Append(inst);
-
-        auto store = builder_.Store(lhs.Get(), inst);
-        current_block_->Append(store);
+        auto* load = current_block_->Append(builder_.Load(lhs.Get()));
+        auto* ty = load->Result()->Type();
+        auto* inst = current_block_->Append(BinaryOp(ty, load->Result(), rhs, op));
+        current_block_->Append(builder_.Store(lhs.Get(), inst));
     }
 
     void EmitBlock(const ast::BlockStatement* block) {
@@ -1240,63 +1180,9 @@
         auto* sem = program_->Sem().Get(expr);
         auto* ty = sem->Type()->Clone(clone_ctx_.type_ctx);
 
-        Binary* inst = nullptr;
-        switch (expr->op) {
-            case ast::BinaryOp::kAnd:
-                inst = builder_.And(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kOr:
-                inst = builder_.Or(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kXor:
-                inst = builder_.Xor(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kEqual:
-                inst = builder_.Equal(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kNotEqual:
-                inst = builder_.NotEqual(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kLessThan:
-                inst = builder_.LessThan(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kGreaterThan:
-                inst = builder_.GreaterThan(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kLessThanEqual:
-                inst = builder_.LessThanEqual(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kGreaterThanEqual:
-                inst = builder_.GreaterThanEqual(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kShiftLeft:
-                inst = builder_.ShiftLeft(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kShiftRight:
-                inst = builder_.ShiftRight(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kAdd:
-                inst = builder_.Add(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kSubtract:
-                inst = builder_.Subtract(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kMultiply:
-                inst = builder_.Multiply(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kDivide:
-                inst = builder_.Divide(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kModulo:
-                inst = builder_.Modulo(ty, lhs.Get(), rhs.Get());
-                break;
-            case ast::BinaryOp::kLogicalAnd:
-            case ast::BinaryOp::kLogicalOr:
-                TINT_ICE(IR, diagnostics_) << "short circuit op should have already been handled";
-                return utils::Failure;
-            case ast::BinaryOp::kNone:
-                TINT_ICE(IR, diagnostics_) << "missing binary operand type";
-                return utils::Failure;
+        Binary* inst = BinaryOp(ty, lhs.Get(), rhs.Get(), expr->op);
+        if (!inst) {
+            return utils::Failure;
         }
 
         current_block_->Append(inst);
@@ -1396,6 +1282,52 @@
         }
         return builder_.Constant(cv);
     }
+
+    ir::Binary* BinaryOp(const type::Type* ty, ir::Value* lhs, ir::Value* rhs, ast::BinaryOp op) {
+        switch (op) {
+            case ast::BinaryOp::kAnd:
+                return builder_.And(ty, lhs, rhs);
+            case ast::BinaryOp::kOr:
+                return builder_.Or(ty, lhs, rhs);
+            case ast::BinaryOp::kXor:
+                return builder_.Xor(ty, lhs, rhs);
+            case ast::BinaryOp::kEqual:
+                return builder_.Equal(ty, lhs, rhs);
+            case ast::BinaryOp::kNotEqual:
+                return builder_.NotEqual(ty, lhs, rhs);
+            case ast::BinaryOp::kLessThan:
+                return builder_.LessThan(ty, lhs, rhs);
+            case ast::BinaryOp::kGreaterThan:
+                return builder_.GreaterThan(ty, lhs, rhs);
+            case ast::BinaryOp::kLessThanEqual:
+                return builder_.LessThanEqual(ty, lhs, rhs);
+            case ast::BinaryOp::kGreaterThanEqual:
+                return builder_.GreaterThanEqual(ty, lhs, rhs);
+            case ast::BinaryOp::kShiftLeft:
+                return builder_.ShiftLeft(ty, lhs, rhs);
+            case ast::BinaryOp::kShiftRight:
+                return builder_.ShiftRight(ty, lhs, rhs);
+            case ast::BinaryOp::kAdd:
+                return builder_.Add(ty, lhs, rhs);
+            case ast::BinaryOp::kSubtract:
+                return builder_.Subtract(ty, lhs, rhs);
+            case ast::BinaryOp::kMultiply:
+                return builder_.Multiply(ty, lhs, rhs);
+            case ast::BinaryOp::kDivide:
+                return builder_.Divide(ty, lhs, rhs);
+            case ast::BinaryOp::kModulo:
+                return builder_.Modulo(ty, lhs, rhs);
+            case ast::BinaryOp::kLogicalAnd:
+            case ast::BinaryOp::kLogicalOr:
+                TINT_ICE(IR, diagnostics_) << "short circuit op should have already been handled";
+                return nullptr;
+            case ast::BinaryOp::kNone:
+                TINT_ICE(IR, diagnostics_) << "missing binary operand type";
+                return nullptr;
+        }
+        TINT_UNREACHABLE(IR, diagnostics_);
+        return nullptr;
+    }
 };
 
 }  // namespace