Resolver: merge binary expression validation and resolving logic together This avoid duplicating the logic in two places, and makes it easier to implement according to the spec. Bug: tint:376 Change-Id: If62f508e2c76b5b661e66aae9ff20b8e874a65d8 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52323 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 15e7f3e..3be647b 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc
@@ -2035,7 +2035,14 @@ return true; } -bool Resolver::ValidateBinary(ast::BinaryExpression* expr) { +bool Resolver::Binary(ast::BinaryExpression* expr) { + Mark(expr->lhs()); + Mark(expr->rhs()); + + if (!Expression(expr->lhs()) || !Expression(expr->rhs())) { + return false; + } + using Bool = sem::Bool; using F32 = sem::F32; using I32 = sem::I32; @@ -2061,14 +2068,17 @@ // Binary logical expressions if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { if (matching_types && lhs_type->Is<Bool>()) { + SetType(expr, lhs_type); return true; } } if (expr->IsOr() || expr->IsAnd()) { if (matching_types && lhs_type->Is<Bool>()) { + SetType(expr, lhs_type); return true; } if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) { + SetType(expr, lhs_type); return true; } } @@ -2076,13 +2086,15 @@ // Arithmetic expressions if (expr->IsArithmetic()) { // Binary arithmetic expressions over scalars - if (matching_types && lhs_type->IsAnyOf<I32, F32, U32>()) { + if (matching_types && lhs_type->is_numeric_scalar()) { + SetType(expr, lhs_type); return true; } // Binary arithmetic expressions over vectors if (matching_types && lhs_vec_elem_type && - lhs_vec_elem_type->IsAnyOf<I32, F32, U32>()) { + lhs_vec_elem_type->is_numeric_scalar()) { + SetType(expr, lhs_type); return true; } } @@ -2093,10 +2105,12 @@ // Multiplication of a vector and a scalar if (lhs_type->Is<F32>() && rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>()) { + SetType(expr, rhs_type); return true; } if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() && rhs_type->Is<F32>()) { + SetType(expr, lhs_type); return true; } @@ -2108,10 +2122,12 @@ // Multiplication of a matrix and a scalar if (lhs_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) { + SetType(expr, rhs_type); return true; } if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_type->Is<F32>()) { + SetType(expr, lhs_type); return true; } @@ -2119,6 +2135,8 @@ if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() && (lhs_vec->size() == rhs_mat->rows())) { + SetType(expr, builder_->create<sem::Vector>(lhs_vec->type(), + rhs_mat->columns())); return true; } @@ -2126,6 +2144,8 @@ if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() && (lhs_mat->columns() == rhs_vec->size())) { + SetType(expr, + builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows())); return true; } @@ -2133,6 +2153,10 @@ if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() && (lhs_mat->columns() == rhs_mat->rows())) { + SetType(expr, builder_->create<sem::Matrix>( + builder_->create<sem::Vector>(lhs_mat_elem_type, + lhs_mat->rows()), + rhs_mat->columns())); return true; } } @@ -2142,11 +2166,13 @@ if (matching_types) { // Special case for bools: only == and != if (lhs_type->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) { + SetType(expr, builder_->create<sem::Bool>()); return true; } // For the rest, we can compare i32, u32, and f32 if (lhs_type->IsAnyOf<I32, U32, F32>()) { + SetType(expr, builder_->create<sem::Bool>()); return true; } } @@ -2155,10 +2181,14 @@ if (matching_vec_elem_types) { if (lhs_vec_elem_type->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) { + SetType(expr, builder_->create<sem::Vector>( + builder_->create<sem::Bool>(), lhs_vec->size())); return true; } - if (lhs_vec_elem_type->IsAnyOf<I32, U32, F32>()) { + if (lhs_vec_elem_type->is_numeric_scalar()) { + SetType(expr, builder_->create<sem::Vector>( + builder_->create<sem::Bool>(), lhs_vec->size())); return true; } } @@ -2167,6 +2197,7 @@ // Binary bitwise operations if (expr->IsBitwise()) { if (matching_types && lhs_type->IsAnyOf<I32, U32>()) { + SetType(expr, lhs_type); return true; } } @@ -2178,11 +2209,13 @@ // logical depending on lhs type). if (lhs_type->IsAnyOf<I32, U32>() && rhs_type->Is<U32>()) { + SetType(expr, lhs_type); return true; } if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() && rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) { + SetType(expr, lhs_type); return true; } } @@ -2196,86 +2229,6 @@ return false; } -bool Resolver::Binary(ast::BinaryExpression* expr) { - Mark(expr->lhs()); - Mark(expr->rhs()); - if (!Expression(expr->lhs()) || !Expression(expr->rhs())) { - return false; - } - - if (!ValidateBinary(expr)) { - return false; - } - - // Result type matches first parameter type - if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() || - expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() || - expr->IsDivide() || expr->IsModulo()) { - SetType(expr, TypeOf(expr->lhs())->UnwrapRef()); - return true; - } - // Result type is a scalar or vector of boolean type - if (expr->IsLogicalAnd() || expr->IsLogicalOr() || expr->IsEqual() || - expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() || - expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) { - auto* bool_type = builder_->create<sem::Bool>(); - auto* param_type = TypeOf(expr->lhs())->UnwrapRef(); - sem::Type* result_type = bool_type; - if (auto* vec = param_type->As<sem::Vector>()) { - result_type = builder_->create<sem::Vector>(bool_type, vec->size()); - } - SetType(expr, result_type); - return true; - } - if (expr->IsMultiply()) { - auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef(); - auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef(); - - // Note, the ordering here matters. The later checks depend on the prior - // checks having been done. - auto* lhs_mat = lhs_type->As<sem::Matrix>(); - auto* rhs_mat = rhs_type->As<sem::Matrix>(); - auto* lhs_vec = lhs_type->As<sem::Vector>(); - auto* rhs_vec = rhs_type->As<sem::Vector>(); - const sem::Type* result_type = nullptr; - if (lhs_mat && rhs_mat) { - auto* column_type = - builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows()); - result_type = - builder_->create<sem::Matrix>(column_type, rhs_mat->columns()); - } else if (lhs_mat && rhs_vec) { - result_type = - builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows()); - } else if (lhs_vec && rhs_mat) { - result_type = - builder_->create<sem::Vector>(rhs_mat->type(), rhs_mat->columns()); - } else if (lhs_mat) { - // matrix * scalar - result_type = lhs_type; - } else if (rhs_mat) { - // scalar * matrix - result_type = rhs_type; - } else if (lhs_vec && rhs_vec) { - result_type = lhs_type; - } else if (lhs_vec) { - // Vector * scalar - result_type = lhs_type; - } else if (rhs_vec) { - // Scalar * vector - result_type = rhs_type; - } else { - // Scalar * Scalar - result_type = lhs_type; - } - - SetType(expr, result_type); - return true; - } - - diagnostics_.add_error("Unknown binary expression", expr->source()); - return false; -} - bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) { Mark(unary->expr());
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index f002f97..8bf89ac 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h
@@ -236,7 +236,6 @@ uint32_t el_align, const Source& source); bool ValidateAssignment(const ast::AssignmentStatement* a); - bool ValidateBinary(ast::BinaryExpression* expr); bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); bool ValidateGlobalVariable(const VariableInfo* var);