resolver: Refactor binary operator type resolution
This same logic will be used for resolving and validating compound
assignment statements, so pull the core out into a separate function
that decouples it from ast::BinaryExpression.
Bug: tint:1325
Change-Id: Ibdb5a7fc8153dac0dd7f9ae3d5164e23585068cd
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/74360
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/ast/binary_expression.h b/src/tint/ast/binary_expression.h
index 7f0f5e7..bcacdd4 100644
--- a/src/tint/ast/binary_expression.h
+++ b/src/tint/ast/binary_expression.h
@@ -122,7 +122,9 @@
const Expression* const rhs;
};
-inline bool BinaryExpression::IsArithmetic() const {
+/// @param op the operator
+/// @returns true if the op is an arithmetic operation
+inline bool IsArithmetic(BinaryOp op) {
switch (op) {
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
@@ -135,7 +137,9 @@
}
}
-inline bool BinaryExpression::IsComparison() const {
+/// @param op the operator
+/// @returns true if the op is a comparison operation
+inline bool IsComparison(BinaryOp op) {
switch (op) {
case ast::BinaryOp::kEqual:
case ast::BinaryOp::kNotEqual:
@@ -149,7 +153,9 @@
}
}
-inline bool BinaryExpression::IsBitwise() const {
+/// @param op the operator
+/// @returns true if the op is a bitwise operation
+inline bool IsBitwise(BinaryOp op) {
switch (op) {
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
@@ -160,7 +166,9 @@
}
}
-inline bool BinaryExpression::IsBitshift() const {
+/// @param op the operator
+/// @returns true if the op is a bit shift operation
+inline bool IsBitshift(BinaryOp op) {
switch (op) {
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
@@ -180,6 +188,22 @@
}
}
+inline bool BinaryExpression::IsArithmetic() const {
+ return ast::IsArithmetic(op);
+}
+
+inline bool BinaryExpression::IsComparison() const {
+ return ast::IsComparison(op);
+}
+
+inline bool BinaryExpression::IsBitwise() const {
+ return ast::IsBitwise(op);
+}
+
+inline bool BinaryExpression::IsBitshift() const {
+ return ast::IsBitshift(op);
+}
+
/// @returns the human readable name of the given BinaryOp
/// @param op the BinaryOp
constexpr const char* FriendlyName(BinaryOp op) {
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 6068a73..e10425a 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1838,6 +1838,33 @@
}
sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
+ auto* lhs = Sem(expr->lhs);
+ auto* rhs = Sem(expr->rhs);
+ auto* lhs_ty = lhs->Type()->UnwrapRef();
+ auto* rhs_ty = rhs->Type()->UnwrapRef();
+
+ auto* ty = BinaryOpType(lhs_ty, rhs_ty, expr->op);
+ if (!ty) {
+ AddError(
+ "Binary expression operand types are invalid for this operation: " +
+ TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
+ TypeNameOf(rhs_ty),
+ expr->source);
+ return nullptr;
+ }
+
+ auto val = EvaluateConstantValue(expr, ty);
+ bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
+ auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
+ val, has_side_effects);
+ sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
+
+ return sem;
+}
+
+const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty,
+ const sem::Type* rhs_ty,
+ ast::BinaryOp op) {
using Bool = sem::Bool;
using F32 = sem::F32;
using I32 = sem::I32;
@@ -1845,12 +1872,6 @@
using Matrix = sem::Matrix;
using Vector = sem::Vector;
- auto* lhs = Sem(expr->lhs);
- auto* rhs = Sem(expr->rhs);
-
- auto* lhs_ty = lhs->Type()->UnwrapRef();
- auto* rhs_ty = rhs->Type()->UnwrapRef();
-
auto* lhs_vec = lhs_ty->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
auto* rhs_vec = rhs_ty->As<Vector>();
@@ -1863,51 +1884,42 @@
const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
- auto build = [&](const sem::Type* ty) {
- auto val = EvaluateConstantValue(expr, ty);
- bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
- auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
- val, has_side_effects);
- sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
- return sem;
- };
-
// Binary logical expressions
- if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
+ if (op == ast::BinaryOp::kLogicalAnd || op == ast::BinaryOp::kLogicalOr) {
if (matching_types && lhs_ty->Is<Bool>()) {
- return build(lhs_ty);
+ return lhs_ty;
}
}
- if (expr->IsOr() || expr->IsAnd()) {
+ if (op == ast::BinaryOp::kOr || op == ast::BinaryOp::kAnd) {
if (matching_types && lhs_ty->Is<Bool>()) {
- return build(lhs_ty);
+ return lhs_ty;
}
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
- return build(lhs_ty);
+ return lhs_ty;
}
}
// Arithmetic expressions
- if (expr->IsArithmetic()) {
+ if (ast::IsArithmetic(op)) {
// Binary arithmetic expressions over scalars
if (matching_types && lhs_ty->is_numeric_scalar()) {
- return build(lhs_ty);
+ return lhs_ty;
}
// Binary arithmetic expressions over vectors
if (matching_types && lhs_vec_elem_type &&
lhs_vec_elem_type->is_numeric_scalar()) {
- return build(lhs_ty);
+ return lhs_ty;
}
// Binary arithmetic expressions with mixed scalar and vector operands
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) &&
rhs_ty->is_numeric_scalar()) {
- return build(lhs_ty);
+ return lhs_ty;
}
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) &&
lhs_ty->is_numeric_scalar()) {
- return build(rhs_ty);
+ return rhs_ty;
}
}
@@ -1917,106 +1929,101 @@
auto* rhs_mat = rhs_ty->As<Matrix>();
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Addition and subtraction of float matrices
- if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
- lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
+ if ((op == ast::BinaryOp::kAdd || op == ast::BinaryOp::kSubtract) &&
+ lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->columns()) &&
(lhs_mat->rows() == rhs_mat->rows())) {
- return build(rhs_ty);
+ return rhs_ty;
}
- if (expr->IsMultiply()) {
+ if (op == ast::BinaryOp::kMultiply) {
// Multiplication of a matrix and a scalar
if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>()) {
- return build(rhs_ty);
+ return rhs_ty;
}
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_ty->Is<F32>()) {
- return build(lhs_ty);
+ return lhs_ty;
}
// Vector times matrix
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_vec->Width() == rhs_mat->rows())) {
- return build(
- builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()));
+ return builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns());
}
// Matrix times vector
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_vec->Width())) {
- return build(
- builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
+ return builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows());
}
// Matrix times matrix
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->rows())) {
- return build(builder_->create<sem::Matrix>(
+ return builder_->create<sem::Matrix>(
builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
- rhs_mat->columns()));
+ rhs_mat->columns());
}
}
// Comparison expressions
- if (expr->IsComparison()) {
+ if (ast::IsComparison(op)) {
if (matching_types) {
// Special case for bools: only == and !=
- if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
- return build(builder_->create<sem::Bool>());
+ if (lhs_ty->Is<Bool>() &&
+ (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
+ return builder_->create<sem::Bool>();
}
// For the rest, we can compare i32, u32, and f32
if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
- return build(builder_->create<sem::Bool>());
+ return builder_->create<sem::Bool>();
}
}
// Same for vectors
if (matching_vec_elem_types) {
if (lhs_vec_elem_type->Is<Bool>() &&
- (expr->IsEqual() || expr->IsNotEqual())) {
- return build(builder_->create<sem::Vector>(
- builder_->create<sem::Bool>(), lhs_vec->Width()));
+ (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
+ return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
+ lhs_vec->Width());
}
if (lhs_vec_elem_type->is_numeric_scalar()) {
- return build(builder_->create<sem::Vector>(
- builder_->create<sem::Bool>(), lhs_vec->Width()));
+ return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
+ lhs_vec->Width());
}
}
}
// Binary bitwise operations
- if (expr->IsBitwise()) {
+ if (ast::IsBitwise(op)) {
if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
- return build(lhs_ty);
+ return lhs_ty;
}
}
// Bit shift expressions
- if (expr->IsBitshift()) {
+ if (ast::IsBitshift(op)) {
// Type validation rules are the same for left or right shift, despite
// differences in computation rules (i.e. right shift can be arithmetic or
// logical depending on lhs type).
if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
- return build(lhs_ty);
+ return lhs_ty;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
- return build(lhs_ty);
+ return lhs_ty;
}
}
- AddError("Binary expression operand types are invalid for this operation: " +
- TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
- TypeNameOf(rhs_ty),
- expr->source);
return nullptr;
}
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index fe7e865..7c3d217 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -229,6 +229,12 @@
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(const ast::StatementList&);
+ // Resolve the result type of a binary operator.
+ // Returns nullptr if the types are not valid for this operator.
+ const sem::Type* BinaryOpType(const sem::Type* lhs_ty,
+ const sem::Type* rhs_ty,
+ ast::BinaryOp op);
+
// AST and Type validation methods
// Each return true on success, false on failure.
bool ValidateAlias(const ast::Alias*);