[tint][ir][FromProgram] Refactor inc/dec/compound assignment
Consolidate the common logic.
Change-Id: I29f486095ad1492d5aeb57ad1fc6a6ba316d627d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/139923
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index c1bd2fb..2835d9d 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -491,99 +491,39 @@
}
void EmitIncrementDecrement(const ast::IncrementDecrementStatement* stmt) {
- auto lhs = EmitExpression(stmt->lhs);
- if (!lhs) {
- return;
- }
+ auto* one = program_->TypeOf(stmt->lhs)->UnwrapRef()->is_signed_integer_scalar()
+ ? builder_.Constant(1_i)
+ : builder_.Constant(1_u);
+ auto emit_rhs = [one] { return one; };
- // Load from the LHS.
- auto* lhs_value = builder_.Load(lhs.Get());
- current_block_->Append(lhs_value);
-
- auto* ty = lhs_value->Result()->Type();
-
- auto* rhs =
- ty->is_signed_integer_scalar() ? builder_.Constant(1_i) : builder_.Constant(1_u);
-
- Binary* inst = nullptr;
- if (stmt->increment) {
- inst = builder_.Add(ty, lhs_value, rhs);
- } else {
- inst = builder_.Subtract(ty, lhs_value, rhs);
- }
- current_block_->Append(inst);
-
- auto store = builder_.Store(lhs.Get(), inst);
- current_block_->Append(store);
+ EmitCompoundAssignment(stmt->lhs, emit_rhs,
+ stmt->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract);
}
void EmitCompoundAssignment(const ast::CompoundAssignmentStatement* stmt) {
- auto lhs = EmitExpression(stmt->lhs);
+ auto emit_rhs = [this, stmt] {
+ auto rhs = EmitExpression(stmt->rhs);
+ return rhs ? rhs.Get() : nullptr;
+ };
+ EmitCompoundAssignment(stmt->lhs, emit_rhs, stmt->op);
+ }
+
+ template <typename EMIT_RHS>
+ void EmitCompoundAssignment(const ast::Expression* lhs_expr,
+ EMIT_RHS&& emit_rhs,
+ ast::BinaryOp op) {
+ auto lhs = EmitExpression(lhs_expr);
if (!lhs) {
return;
}
-
- auto rhs = EmitExpression(stmt->rhs);
+ auto rhs = emit_rhs();
if (!rhs) {
return;
}
-
- // Load from the LHS.
- auto* lhs_value = builder_.Load(lhs.Get());
- current_block_->Append(lhs_value);
-
- auto* ty = lhs_value->Result()->Type();
-
- Binary* inst = nullptr;
- switch (stmt->op) {
- case ast::BinaryOp::kAnd:
- inst = builder_.And(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kOr:
- inst = builder_.Or(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kXor:
- inst = builder_.Xor(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kShiftLeft:
- inst = builder_.ShiftLeft(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kShiftRight:
- inst = builder_.ShiftRight(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kAdd:
- inst = builder_.Add(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kSubtract:
- inst = builder_.Subtract(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kMultiply:
- inst = builder_.Multiply(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kDivide:
- inst = builder_.Divide(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kModulo:
- inst = builder_.Modulo(ty, lhs_value, rhs.Get());
- break;
- case ast::BinaryOp::kLessThanEqual:
- case ast::BinaryOp::kGreaterThanEqual:
- case ast::BinaryOp::kGreaterThan:
- case ast::BinaryOp::kLessThan:
- case ast::BinaryOp::kNotEqual:
- case ast::BinaryOp::kEqual:
- case ast::BinaryOp::kLogicalAnd:
- case ast::BinaryOp::kLogicalOr:
- TINT_ICE(IR, diagnostics_) << "invalid compound assignment";
- return;
- case ast::BinaryOp::kNone:
- TINT_ICE(IR, diagnostics_) << "missing binary operand type";
- return;
- }
- current_block_->Append(inst);
-
- auto store = builder_.Store(lhs.Get(), inst);
- current_block_->Append(store);
+ auto* load = current_block_->Append(builder_.Load(lhs.Get()));
+ auto* ty = load->Result()->Type();
+ auto* inst = current_block_->Append(BinaryOp(ty, load->Result(), rhs, op));
+ current_block_->Append(builder_.Store(lhs.Get(), inst));
}
void EmitBlock(const ast::BlockStatement* block) {
@@ -1240,63 +1180,9 @@
auto* sem = program_->Sem().Get(expr);
auto* ty = sem->Type()->Clone(clone_ctx_.type_ctx);
- Binary* inst = nullptr;
- switch (expr->op) {
- case ast::BinaryOp::kAnd:
- inst = builder_.And(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kOr:
- inst = builder_.Or(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kXor:
- inst = builder_.Xor(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kEqual:
- inst = builder_.Equal(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kNotEqual:
- inst = builder_.NotEqual(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kLessThan:
- inst = builder_.LessThan(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kGreaterThan:
- inst = builder_.GreaterThan(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kLessThanEqual:
- inst = builder_.LessThanEqual(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kGreaterThanEqual:
- inst = builder_.GreaterThanEqual(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kShiftLeft:
- inst = builder_.ShiftLeft(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kShiftRight:
- inst = builder_.ShiftRight(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kAdd:
- inst = builder_.Add(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kSubtract:
- inst = builder_.Subtract(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kMultiply:
- inst = builder_.Multiply(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kDivide:
- inst = builder_.Divide(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kModulo:
- inst = builder_.Modulo(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kLogicalAnd:
- case ast::BinaryOp::kLogicalOr:
- TINT_ICE(IR, diagnostics_) << "short circuit op should have already been handled";
- return utils::Failure;
- case ast::BinaryOp::kNone:
- TINT_ICE(IR, diagnostics_) << "missing binary operand type";
- return utils::Failure;
+ Binary* inst = BinaryOp(ty, lhs.Get(), rhs.Get(), expr->op);
+ if (!inst) {
+ return utils::Failure;
}
current_block_->Append(inst);
@@ -1396,6 +1282,52 @@
}
return builder_.Constant(cv);
}
+
+ ir::Binary* BinaryOp(const type::Type* ty, ir::Value* lhs, ir::Value* rhs, ast::BinaryOp op) {
+ switch (op) {
+ case ast::BinaryOp::kAnd:
+ return builder_.And(ty, lhs, rhs);
+ case ast::BinaryOp::kOr:
+ return builder_.Or(ty, lhs, rhs);
+ case ast::BinaryOp::kXor:
+ return builder_.Xor(ty, lhs, rhs);
+ case ast::BinaryOp::kEqual:
+ return builder_.Equal(ty, lhs, rhs);
+ case ast::BinaryOp::kNotEqual:
+ return builder_.NotEqual(ty, lhs, rhs);
+ case ast::BinaryOp::kLessThan:
+ return builder_.LessThan(ty, lhs, rhs);
+ case ast::BinaryOp::kGreaterThan:
+ return builder_.GreaterThan(ty, lhs, rhs);
+ case ast::BinaryOp::kLessThanEqual:
+ return builder_.LessThanEqual(ty, lhs, rhs);
+ case ast::BinaryOp::kGreaterThanEqual:
+ return builder_.GreaterThanEqual(ty, lhs, rhs);
+ case ast::BinaryOp::kShiftLeft:
+ return builder_.ShiftLeft(ty, lhs, rhs);
+ case ast::BinaryOp::kShiftRight:
+ return builder_.ShiftRight(ty, lhs, rhs);
+ case ast::BinaryOp::kAdd:
+ return builder_.Add(ty, lhs, rhs);
+ case ast::BinaryOp::kSubtract:
+ return builder_.Subtract(ty, lhs, rhs);
+ case ast::BinaryOp::kMultiply:
+ return builder_.Multiply(ty, lhs, rhs);
+ case ast::BinaryOp::kDivide:
+ return builder_.Divide(ty, lhs, rhs);
+ case ast::BinaryOp::kModulo:
+ return builder_.Modulo(ty, lhs, rhs);
+ case ast::BinaryOp::kLogicalAnd:
+ case ast::BinaryOp::kLogicalOr:
+ TINT_ICE(IR, diagnostics_) << "short circuit op should have already been handled";
+ return nullptr;
+ case ast::BinaryOp::kNone:
+ TINT_ICE(IR, diagnostics_) << "missing binary operand type";
+ return nullptr;
+ }
+ TINT_UNREACHABLE(IR, diagnostics_);
+ return nullptr;
+ }
};
} // namespace