resolver: Refactor binary operator type resolution This same logic will be used for resolving and validating compound assignment statements, so pull the core out into a separate function that decouples it from ast::BinaryExpression. Bug: tint:1325 Change-Id: Ibdb5a7fc8153dac0dd7f9ae3d5164e23585068cd Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/74360 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/ast/binary_expression.h b/src/tint/ast/binary_expression.h index 7f0f5e7..bcacdd4 100644 --- a/src/tint/ast/binary_expression.h +++ b/src/tint/ast/binary_expression.h
@@ -122,7 +122,9 @@ const Expression* const rhs; }; -inline bool BinaryExpression::IsArithmetic() const { +/// @param op the operator +/// @returns true if the op is an arithmetic operation +inline bool IsArithmetic(BinaryOp op) { switch (op) { case ast::BinaryOp::kAdd: case ast::BinaryOp::kSubtract: @@ -135,7 +137,9 @@ } } -inline bool BinaryExpression::IsComparison() const { +/// @param op the operator +/// @returns true if the op is a comparison operation +inline bool IsComparison(BinaryOp op) { switch (op) { case ast::BinaryOp::kEqual: case ast::BinaryOp::kNotEqual: @@ -149,7 +153,9 @@ } } -inline bool BinaryExpression::IsBitwise() const { +/// @param op the operator +/// @returns true if the op is a bitwise operation +inline bool IsBitwise(BinaryOp op) { switch (op) { case ast::BinaryOp::kAnd: case ast::BinaryOp::kOr: @@ -160,7 +166,9 @@ } } -inline bool BinaryExpression::IsBitshift() const { +/// @param op the operator +/// @returns true if the op is a bit shift operation +inline bool IsBitshift(BinaryOp op) { switch (op) { case ast::BinaryOp::kShiftLeft: case ast::BinaryOp::kShiftRight: @@ -180,6 +188,22 @@ } } +inline bool BinaryExpression::IsArithmetic() const { + return ast::IsArithmetic(op); +} + +inline bool BinaryExpression::IsComparison() const { + return ast::IsComparison(op); +} + +inline bool BinaryExpression::IsBitwise() const { + return ast::IsBitwise(op); +} + +inline bool BinaryExpression::IsBitshift() const { + return ast::IsBitshift(op); +} + /// @returns the human readable name of the given BinaryOp /// @param op the BinaryOp constexpr const char* FriendlyName(BinaryOp op) {
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 6068a73..e10425a 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc
@@ -1838,6 +1838,33 @@ } sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { + auto* lhs = Sem(expr->lhs); + auto* rhs = Sem(expr->rhs); + auto* lhs_ty = lhs->Type()->UnwrapRef(); + auto* rhs_ty = rhs->Type()->UnwrapRef(); + + auto* ty = BinaryOpType(lhs_ty, rhs_ty, expr->op); + if (!ty) { + AddError( + "Binary expression operand types are invalid for this operation: " + + TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " + + TypeNameOf(rhs_ty), + expr->source); + return nullptr; + } + + auto val = EvaluateConstantValue(expr, ty); + bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); + auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, + val, has_side_effects); + sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors(); + + return sem; +} + +const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty, + const sem::Type* rhs_ty, + ast::BinaryOp op) { using Bool = sem::Bool; using F32 = sem::F32; using I32 = sem::I32; @@ -1845,12 +1872,6 @@ using Matrix = sem::Matrix; using Vector = sem::Vector; - auto* lhs = Sem(expr->lhs); - auto* rhs = Sem(expr->rhs); - - auto* lhs_ty = lhs->Type()->UnwrapRef(); - auto* rhs_ty = rhs->Type()->UnwrapRef(); - auto* lhs_vec = lhs_ty->As<Vector>(); auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr; auto* rhs_vec = rhs_ty->As<Vector>(); @@ -1863,51 +1884,42 @@ const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty); - auto build = [&](const sem::Type* ty) { - auto val = EvaluateConstantValue(expr, ty); - bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); - auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, - val, has_side_effects); - sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors(); - return sem; - }; - // Binary logical expressions - if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { + if (op == ast::BinaryOp::kLogicalAnd || op == ast::BinaryOp::kLogicalOr) { if (matching_types && lhs_ty->Is<Bool>()) { - return build(lhs_ty); + return lhs_ty; } } - if (expr->IsOr() || expr->IsAnd()) { + if (op == ast::BinaryOp::kOr || op == ast::BinaryOp::kAnd) { if (matching_types && lhs_ty->Is<Bool>()) { - return build(lhs_ty); + return lhs_ty; } if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) { - return build(lhs_ty); + return lhs_ty; } } // Arithmetic expressions - if (expr->IsArithmetic()) { + if (ast::IsArithmetic(op)) { // Binary arithmetic expressions over scalars if (matching_types && lhs_ty->is_numeric_scalar()) { - return build(lhs_ty); + return lhs_ty; } // Binary arithmetic expressions over vectors if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->is_numeric_scalar()) { - return build(lhs_ty); + return lhs_ty; } // Binary arithmetic expressions with mixed scalar and vector operands if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) && rhs_ty->is_numeric_scalar()) { - return build(lhs_ty); + return lhs_ty; } if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) && lhs_ty->is_numeric_scalar()) { - return build(rhs_ty); + return rhs_ty; } } @@ -1917,106 +1929,101 @@ auto* rhs_mat = rhs_ty->As<Matrix>(); auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr; // Addition and subtraction of float matrices - if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type && - lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type && + if ((op == ast::BinaryOp::kAdd || op == ast::BinaryOp::kSubtract) && + lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() && (lhs_mat->columns() == rhs_mat->columns()) && (lhs_mat->rows() == rhs_mat->rows())) { - return build(rhs_ty); + return rhs_ty; } - if (expr->IsMultiply()) { + if (op == ast::BinaryOp::kMultiply) { // Multiplication of a matrix and a scalar if (lhs_ty->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) { - return build(rhs_ty); + return rhs_ty; } if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_ty->Is<F32>()) { - return build(lhs_ty); + return lhs_ty; } // Vector times matrix if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() && (lhs_vec->Width() == rhs_mat->rows())) { - return build( - builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns())); + return builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()); } // Matrix times vector if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() && (lhs_mat->columns() == rhs_vec->Width())) { - return build( - builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows())); + return builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()); } // Matrix times matrix if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() && (lhs_mat->columns() == rhs_mat->rows())) { - return build(builder_->create<sem::Matrix>( + return builder_->create<sem::Matrix>( builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()), - rhs_mat->columns())); + rhs_mat->columns()); } } // Comparison expressions - if (expr->IsComparison()) { + if (ast::IsComparison(op)) { if (matching_types) { // Special case for bools: only == and != - if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) { - return build(builder_->create<sem::Bool>()); + if (lhs_ty->Is<Bool>() && + (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) { + return builder_->create<sem::Bool>(); } // For the rest, we can compare i32, u32, and f32 if (lhs_ty->IsAnyOf<I32, U32, F32>()) { - return build(builder_->create<sem::Bool>()); + return builder_->create<sem::Bool>(); } } // Same for vectors if (matching_vec_elem_types) { if (lhs_vec_elem_type->Is<Bool>() && - (expr->IsEqual() || expr->IsNotEqual())) { - return build(builder_->create<sem::Vector>( - builder_->create<sem::Bool>(), lhs_vec->Width())); + (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) { + return builder_->create<sem::Vector>(builder_->create<sem::Bool>(), + lhs_vec->Width()); } if (lhs_vec_elem_type->is_numeric_scalar()) { - return build(builder_->create<sem::Vector>( - builder_->create<sem::Bool>(), lhs_vec->Width())); + return builder_->create<sem::Vector>(builder_->create<sem::Bool>(), + lhs_vec->Width()); } } } // Binary bitwise operations - if (expr->IsBitwise()) { + if (ast::IsBitwise(op)) { if (matching_types && lhs_ty->is_integer_scalar_or_vector()) { - return build(lhs_ty); + return lhs_ty; } } // Bit shift expressions - if (expr->IsBitshift()) { + if (ast::IsBitshift(op)) { // Type validation rules are the same for left or right shift, despite // differences in computation rules (i.e. right shift can be arithmetic or // logical depending on lhs type). if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) { - return build(lhs_ty); + return lhs_ty; } if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() && rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) { - return build(lhs_ty); + return lhs_ty; } } - AddError("Binary expression operand types are invalid for this operation: " + - TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " + - TypeNameOf(rhs_ty), - expr->source); return nullptr; }
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index fe7e865..7c3d217 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h
@@ -229,6 +229,12 @@ sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*); bool Statements(const ast::StatementList&); + // Resolve the result type of a binary operator. + // Returns nullptr if the types are not valid for this operator. + const sem::Type* BinaryOpType(const sem::Type* lhs_ty, + const sem::Type* rhs_ty, + ast::BinaryOp op); + // AST and Type validation methods // Each return true on success, false on failure. bool ValidateAlias(const ast::Alias*);