Add unary op expresison type determination.
This CL adds type determination for the unary op expression.
Bug: tint:5
Change-Id: I5b9c0c80bb48527f1f26febb2310f9640e5f7849
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18850
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index c8d4952..51e8885 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -41,6 +41,7 @@
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
+#include "src/ast/unary_op_expression.h"
#include "src/ast/unary_derivative_expression.h"
#include "src/ast/unary_method_expression.h"
#include "src/ast/unless_statement.h"
@@ -220,6 +221,9 @@
if (expr->IsUnaryMethod()) {
return DetermineUnaryMethod(expr->AsUnaryMethod());
}
+ if (expr->IsUnaryOp()) {
+ return DetermineUnaryOp(expr->AsUnaryOp());
+ }
error_ = "unknown expression for type determination";
return false;
@@ -482,4 +486,13 @@
return true;
}
+bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) {
+ // Result type matches the parameter type.
+ if (!DetermineResultType(expr->expr())) {
+ return false;
+ }
+ expr->set_result_type(expr->expr()->result_type());
+ return true;
+}
+
} // namespace tint
diff --git a/src/type_determiner.h b/src/type_determiner.h
index 126b755..0b72b59 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -36,6 +36,7 @@
class RelationalExpression;
class UnaryDerivativeExpression;
class UnaryMethodExpression;
+class UnaryOpExpression;
class Variable;
} // namespace ast
@@ -87,6 +88,7 @@
bool DetermineRelational(ast::RelationalExpression* expr);
bool DetermineUnaryDerivative(ast::UnaryDerivativeExpression* expr);
bool DetermineUnaryMethod(ast::UnaryMethodExpression* expr);
+ bool DetermineUnaryOp(ast::UnaryOpExpression* expr);
Context& ctx_;
std::string error_;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 1d63956..5e9179d 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -41,6 +41,7 @@
#include "src/ast/struct.h"
#include "src/ast/struct_member.h"
#include "src/ast/switch_statement.h"
+#include "src/ast/unary_op_expression.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
@@ -1428,5 +1429,38 @@
EXPECT_EQ(mat->columns(), 2);
}
+using UnaryOpExpressionTest = testing::TestWithParam<ast::UnaryOp>;
+TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) {
+ auto op = GetParam();
+
+ ast::type::F32Type f32;
+
+ ast::type::VectorType vec4(&f32, 4);
+
+ auto var =
+ std::make_unique<ast::Variable>("ident", ast::StorageClass::kNone, &vec4);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ // Register the global
+ EXPECT_TRUE(td.Determine(&m));
+
+ ast::UnaryOpExpression der(
+ op, std::make_unique<ast::IdentifierExpression>("ident"));
+ EXPECT_TRUE(td.DetermineResultType(&der));
+ ASSERT_NE(der.result_type(), nullptr);
+ ASSERT_TRUE(der.result_type()->IsVector());
+ EXPECT_TRUE(der.result_type()->AsVector()->type()->IsF32());
+ EXPECT_EQ(der.result_type()->AsVector()->size(), 4);
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+ UnaryOpExpressionTest,
+ testing::Values(ast::UnaryOp::kNegation,
+ ast::UnaryOp::kNot));
+
} // namespace
} // namespace tint