Add type determination for unary method.
This CL adds the type determination for the unary method expression.
Bug: tint:5
Change-Id: I9f94a79b9715cf74e37c74eb1a612ca84b3c241f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18849
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 8f07dc4..c8d4952 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -36,11 +36,13 @@
#include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/bool_type.h"
+#include "src/ast/type/f32_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_derivative_expression.h"
+#include "src/ast/unary_method_expression.h"
#include "src/ast/unless_statement.h"
#include "src/ast/variable_decl_statement.h"
@@ -215,6 +217,9 @@
if (expr->IsUnaryDerivative()) {
return DetermineUnaryDerivative(expr->AsUnaryDerivative());
}
+ if (expr->IsUnaryMethod()) {
+ return DetermineUnaryMethod(expr->AsUnaryMethod());
+ }
error_ = "unknown expression for type determination";
return false;
@@ -415,4 +420,66 @@
return true;
}
+bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
+ for (const auto& param : expr->params()) {
+ if (!DetermineResultType(param.get())) {
+ return false;
+ }
+ }
+
+ switch (expr->op()) {
+ case ast::UnaryMethod::kAny:
+ case ast::UnaryMethod::kAll: {
+ expr->set_result_type(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()));
+ break;
+ }
+ case ast::UnaryMethod::kIsNan:
+ case ast::UnaryMethod::kIsInf:
+ case ast::UnaryMethod::kIsFinite:
+ case ast::UnaryMethod::kIsNormal: {
+ if (expr->params().empty()) {
+ error_ = "incorrect number of parameters";
+ return false;
+ }
+
+ auto bool_type =
+ ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+ auto param_type = expr->params()[0]->result_type();
+ if (param_type->IsVector()) {
+ expr->set_result_type(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+ bool_type, param_type->AsVector()->size())));
+ } else {
+ expr->set_result_type(bool_type);
+ }
+ break;
+ }
+ case ast::UnaryMethod::kDot: {
+ expr->set_result_type(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()));
+ break;
+ }
+ case ast::UnaryMethod::kOuterProduct: {
+ if (expr->params().size() != 2) {
+ error_ = "incorrect number of parameters for outer product";
+ return false;
+ }
+ auto param0_type = expr->params()[0]->result_type();
+ auto param1_type = expr->params()[1]->result_type();
+ if (!param0_type->IsVector() || !param1_type->IsVector()) {
+ error_ = "invalid parameter type for outer product";
+ return false;
+ }
+ expr->set_result_type(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()),
+ param0_type->AsVector()->size(),
+ param1_type->AsVector()->size())));
+ break;
+ }
+ }
+ return true;
+}
+
} // namespace tint
diff --git a/src/type_determiner.h b/src/type_determiner.h
index de6a188..126b755 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -35,6 +35,7 @@
class MemberAccessorExpression;
class RelationalExpression;
class UnaryDerivativeExpression;
+class UnaryMethodExpression;
class Variable;
} // namespace ast
@@ -85,6 +86,7 @@
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
bool DetermineRelational(ast::RelationalExpression* expr);
bool DetermineUnaryDerivative(ast::UnaryDerivativeExpression* expr);
+ bool DetermineUnaryMethod(ast::UnaryMethodExpression* expr);
Context& ctx_;
std::string error_;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index a416916..1d63956 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -50,6 +50,7 @@
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_derivative_expression.h"
+#include "src/ast/unary_method_expression.h"
#include "src/ast/unless_statement.h"
#include "src/ast/variable_decl_statement.h"
@@ -1268,5 +1269,164 @@
ast::UnaryDerivative::kDpdy,
ast::UnaryDerivative::kFwidth));
+using UnaryMethodExpressionBoolTest = testing::TestWithParam<ast::UnaryMethod>;
+TEST_P(UnaryMethodExpressionBoolTest, Expr_UnaryMethod_Any) {
+ auto op = GetParam();
+
+ ast::type::BoolType bool_type;
+ ast::type::VectorType vec3(&bool_type, 3);
+
+ auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
+ &vec3);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+
+ ast::UnaryMethodExpression exp(op, std::move(params));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ // Register the variable
+ EXPECT_TRUE(td.Determine(&m));
+
+ EXPECT_TRUE(td.DetermineResultType(&exp));
+ ASSERT_NE(exp.result_type(), nullptr);
+ EXPECT_TRUE(exp.result_type()->IsBool());
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+ UnaryMethodExpressionBoolTest,
+ testing::Values(ast::UnaryMethod::kAny,
+ ast::UnaryMethod::kAll));
+
+using UnaryMethodExpressionVecTest = testing::TestWithParam<ast::UnaryMethod>;
+TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Bool) {
+ auto op = GetParam();
+
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
+ &vec3);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+
+ ast::UnaryMethodExpression exp(op, std::move(params));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ // Register the variable
+ EXPECT_TRUE(td.Determine(&m));
+
+ EXPECT_TRUE(td.DetermineResultType(&exp));
+ ASSERT_NE(exp.result_type(), nullptr);
+ ASSERT_TRUE(exp.result_type()->IsVector());
+ EXPECT_TRUE(exp.result_type()->AsVector()->type()->IsBool());
+ EXPECT_EQ(exp.result_type()->AsVector()->size(), 3);
+}
+TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Vec) {
+ auto op = GetParam();
+
+ ast::type::F32Type f32;
+
+ auto var =
+ std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+
+ ast::UnaryMethodExpression exp(op, std::move(params));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ // Register the variable
+ EXPECT_TRUE(td.Determine(&m));
+
+ EXPECT_TRUE(td.DetermineResultType(&exp));
+ ASSERT_NE(exp.result_type(), nullptr);
+ EXPECT_TRUE(exp.result_type()->IsBool());
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+ UnaryMethodExpressionVecTest,
+ testing::Values(ast::UnaryMethod::kIsInf,
+ ast::UnaryMethod::kIsNan,
+ ast::UnaryMethod::kIsFinite,
+ ast::UnaryMethod::kIsNormal));
+
+TEST_F(TypeDeterminerTest, Expr_UnaryMethod_Dot) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
+ &vec3);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+ params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+
+ ast::UnaryMethodExpression exp(ast::UnaryMethod::kDot, std::move(params));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ // Register the variable
+ EXPECT_TRUE(td.Determine(&m));
+
+ EXPECT_TRUE(td.DetermineResultType(&exp));
+ ASSERT_NE(exp.result_type(), nullptr);
+ EXPECT_TRUE(exp.result_type()->IsF32());
+}
+
+TEST_F(TypeDeterminerTest, Expr_UnaryMethod_OuterProduct) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::VectorType vec2(&f32, 2);
+
+ auto var1 =
+ std::make_unique<ast::Variable>("v3", ast::StorageClass::kNone, &vec3);
+ auto var2 =
+ std::make_unique<ast::Variable>("v2", ast::StorageClass::kNone, &vec2);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var1));
+ m.AddGlobalVariable(std::move(var2));
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("v3"));
+ params.push_back(std::make_unique<ast::IdentifierExpression>("v2"));
+
+ ast::UnaryMethodExpression exp(ast::UnaryMethod::kOuterProduct,
+ std::move(params));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ // Register the variable
+ EXPECT_TRUE(td.Determine(&m));
+
+ EXPECT_TRUE(td.DetermineResultType(&exp));
+ ASSERT_NE(exp.result_type(), nullptr);
+ ASSERT_TRUE(exp.result_type()->IsMatrix());
+ auto mat = exp.result_type()->AsMatrix();
+ EXPECT_TRUE(mat->type()->IsF32());
+ EXPECT_EQ(mat->rows(), 3);
+ EXPECT_EQ(mat->columns(), 2);
+}
+
} // namespace
} // namespace tint