Resolver: merge binary expression validation and resolving logic together
This avoid duplicating the logic in two places, and makes it easier to
implement according to the spec.
Bug: tint:376
Change-Id: If62f508e2c76b5b661e66aae9ff20b8e874a65d8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52323
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 15e7f3e..3be647b 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2035,7 +2035,14 @@
return true;
}
-bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
+bool Resolver::Binary(ast::BinaryExpression* expr) {
+ Mark(expr->lhs());
+ Mark(expr->rhs());
+
+ if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
+ return false;
+ }
+
using Bool = sem::Bool;
using F32 = sem::F32;
using I32 = sem::I32;
@@ -2061,14 +2068,17 @@
// Binary logical expressions
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
if (matching_types && lhs_type->Is<Bool>()) {
+ SetType(expr, lhs_type);
return true;
}
}
if (expr->IsOr() || expr->IsAnd()) {
if (matching_types && lhs_type->Is<Bool>()) {
+ SetType(expr, lhs_type);
return true;
}
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
+ SetType(expr, lhs_type);
return true;
}
}
@@ -2076,13 +2086,15 @@
// Arithmetic expressions
if (expr->IsArithmetic()) {
// Binary arithmetic expressions over scalars
- if (matching_types && lhs_type->IsAnyOf<I32, F32, U32>()) {
+ if (matching_types && lhs_type->is_numeric_scalar()) {
+ SetType(expr, lhs_type);
return true;
}
// Binary arithmetic expressions over vectors
if (matching_types && lhs_vec_elem_type &&
- lhs_vec_elem_type->IsAnyOf<I32, F32, U32>()) {
+ lhs_vec_elem_type->is_numeric_scalar()) {
+ SetType(expr, lhs_type);
return true;
}
}
@@ -2093,10 +2105,12 @@
// Multiplication of a vector and a scalar
if (lhs_type->Is<F32>() && rhs_vec_elem_type &&
rhs_vec_elem_type->Is<F32>()) {
+ SetType(expr, rhs_type);
return true;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_type->Is<F32>()) {
+ SetType(expr, lhs_type);
return true;
}
@@ -2108,10 +2122,12 @@
// Multiplication of a matrix and a scalar
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>()) {
+ SetType(expr, rhs_type);
return true;
}
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_type->Is<F32>()) {
+ SetType(expr, lhs_type);
return true;
}
@@ -2119,6 +2135,8 @@
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_vec->size() == rhs_mat->rows())) {
+ SetType(expr, builder_->create<sem::Vector>(lhs_vec->type(),
+ rhs_mat->columns()));
return true;
}
@@ -2126,6 +2144,8 @@
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_vec->size())) {
+ SetType(expr,
+ builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
return true;
}
@@ -2133,6 +2153,10 @@
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->rows())) {
+ SetType(expr, builder_->create<sem::Matrix>(
+ builder_->create<sem::Vector>(lhs_mat_elem_type,
+ lhs_mat->rows()),
+ rhs_mat->columns()));
return true;
}
}
@@ -2142,11 +2166,13 @@
if (matching_types) {
// Special case for bools: only == and !=
if (lhs_type->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
+ SetType(expr, builder_->create<sem::Bool>());
return true;
}
// For the rest, we can compare i32, u32, and f32
if (lhs_type->IsAnyOf<I32, U32, F32>()) {
+ SetType(expr, builder_->create<sem::Bool>());
return true;
}
}
@@ -2155,10 +2181,14 @@
if (matching_vec_elem_types) {
if (lhs_vec_elem_type->Is<Bool>() &&
(expr->IsEqual() || expr->IsNotEqual())) {
+ SetType(expr, builder_->create<sem::Vector>(
+ builder_->create<sem::Bool>(), lhs_vec->size()));
return true;
}
- if (lhs_vec_elem_type->IsAnyOf<I32, U32, F32>()) {
+ if (lhs_vec_elem_type->is_numeric_scalar()) {
+ SetType(expr, builder_->create<sem::Vector>(
+ builder_->create<sem::Bool>(), lhs_vec->size()));
return true;
}
}
@@ -2167,6 +2197,7 @@
// Binary bitwise operations
if (expr->IsBitwise()) {
if (matching_types && lhs_type->IsAnyOf<I32, U32>()) {
+ SetType(expr, lhs_type);
return true;
}
}
@@ -2178,11 +2209,13 @@
// logical depending on lhs type).
if (lhs_type->IsAnyOf<I32, U32>() && rhs_type->Is<U32>()) {
+ SetType(expr, lhs_type);
return true;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
+ SetType(expr, lhs_type);
return true;
}
}
@@ -2196,86 +2229,6 @@
return false;
}
-bool Resolver::Binary(ast::BinaryExpression* expr) {
- Mark(expr->lhs());
- Mark(expr->rhs());
- if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
- return false;
- }
-
- if (!ValidateBinary(expr)) {
- return false;
- }
-
- // Result type matches first parameter type
- if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
- expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() ||
- expr->IsDivide() || expr->IsModulo()) {
- SetType(expr, TypeOf(expr->lhs())->UnwrapRef());
- return true;
- }
- // Result type is a scalar or vector of boolean type
- if (expr->IsLogicalAnd() || expr->IsLogicalOr() || expr->IsEqual() ||
- expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
- expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
- auto* bool_type = builder_->create<sem::Bool>();
- auto* param_type = TypeOf(expr->lhs())->UnwrapRef();
- sem::Type* result_type = bool_type;
- if (auto* vec = param_type->As<sem::Vector>()) {
- result_type = builder_->create<sem::Vector>(bool_type, vec->size());
- }
- SetType(expr, result_type);
- return true;
- }
- if (expr->IsMultiply()) {
- auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
- auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
-
- // Note, the ordering here matters. The later checks depend on the prior
- // checks having been done.
- auto* lhs_mat = lhs_type->As<sem::Matrix>();
- auto* rhs_mat = rhs_type->As<sem::Matrix>();
- auto* lhs_vec = lhs_type->As<sem::Vector>();
- auto* rhs_vec = rhs_type->As<sem::Vector>();
- const sem::Type* result_type = nullptr;
- if (lhs_mat && rhs_mat) {
- auto* column_type =
- builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
- result_type =
- builder_->create<sem::Matrix>(column_type, rhs_mat->columns());
- } else if (lhs_mat && rhs_vec) {
- result_type =
- builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
- } else if (lhs_vec && rhs_mat) {
- result_type =
- builder_->create<sem::Vector>(rhs_mat->type(), rhs_mat->columns());
- } else if (lhs_mat) {
- // matrix * scalar
- result_type = lhs_type;
- } else if (rhs_mat) {
- // scalar * matrix
- result_type = rhs_type;
- } else if (lhs_vec && rhs_vec) {
- result_type = lhs_type;
- } else if (lhs_vec) {
- // Vector * scalar
- result_type = lhs_type;
- } else if (rhs_vec) {
- // Scalar * vector
- result_type = rhs_type;
- } else {
- // Scalar * Scalar
- result_type = lhs_type;
- }
-
- SetType(expr, result_type);
- return true;
- }
-
- diagnostics_.add_error("Unknown binary expression", expr->source());
- return false;
-}
-
bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) {
Mark(unary->expr());
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index f002f97..8bf89ac 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -236,7 +236,6 @@
uint32_t el_align,
const Source& source);
bool ValidateAssignment(const ast::AssignmentStatement* a);
- bool ValidateBinary(ast::BinaryExpression* expr);
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
bool ValidateGlobalVariable(const VariableInfo* var);