Add type determination for member accessor.
This Cl adds the member accessor type determination for both structures
and vector swizzles.
Bug: tint:5
Change-Id: I1172db29d8cbed2d9e0ae228ebc3a818d4930b7f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18846
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/relational_expression.h b/src/ast/relational_expression.h
index 04b9064..f5bd944 100644
--- a/src/ast/relational_expression.h
+++ b/src/ast/relational_expression.h
@@ -48,7 +48,7 @@
kModulo,
};
-/// A Relational Expression
+/// An xor expression
class RelationalExpression : public Expression {
public:
/// Constructor
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index f1da09c..9a81e19 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -28,12 +28,14 @@
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
+#include "src/ast/member_accessor_expression.h"
#include "src/ast/regardless_statement.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/matrix_type.h"
+#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/unless_statement.h"
@@ -41,7 +43,10 @@
namespace tint {
-TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {}
+TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {
+ // TODO(dsinclair): Temporary usage to avoid compiler warning
+ static_cast<void>(ctx_.type_mgr());
+}
TypeDeterminer::~TypeDeterminer() = default;
@@ -174,15 +179,6 @@
return false;
}
-bool TypeDeterminer::DetermineResultType(const ast::ExpressionList& exprs) {
- for (const auto& expr : exprs) {
- if (!DetermineResultType(expr.get())) {
- return false;
- }
- }
- return true;
-}
-
bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
// This is blindly called above, so in some cases the expression won't exist.
if (!expr) {
@@ -207,6 +203,9 @@
if (expr->IsIdentifier()) {
return DetermineIdentifier(expr->AsIdentifier());
}
+ if (expr->IsMemberAccessor()) {
+ return DetermineMemberAccessor(expr->AsMemberAccessor());
+ }
error_ = "unknown expression for type determination";
return false;
@@ -242,9 +241,6 @@
if (!DetermineResultType(expr->func())) {
return false;
}
- if (!DetermineResultType(expr->params())) {
- return false;
- }
expr->set_result_type(expr->func()->result_type());
return true;
}
@@ -283,7 +279,45 @@
return true;
}
- error_ = "unknown identifier for type determination";
+ return true;
+}
+
+bool TypeDeterminer::DetermineMemberAccessor(
+ ast::MemberAccessorExpression* expr) {
+ if (!DetermineResultType(expr->structure())) {
+ return false;
+ }
+
+ auto data_type = expr->structure()->result_type();
+ if (data_type->IsStruct()) {
+ auto strct = data_type->AsStruct()->impl();
+ auto name = expr->member()->name()[0];
+
+ for (const auto& member : strct->members()) {
+ if (member->name() != name) {
+ continue;
+ }
+
+ expr->set_result_type(member->type());
+ return true;
+ }
+
+ error_ = "struct member not found";
+ return false;
+ }
+ if (data_type->IsVector()) {
+ auto vec = data_type->AsVector();
+
+ // The vector will have a number of components equal to the length of the
+ // swizzle. This assumes the validator will check that the swizzle
+ // is correct.
+ expr->set_result_type(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+ vec->type(), expr->member()->name()[0].size())));
+ return true;
+ }
+
+ error_ = "invalid type in member accessor";
return false;
}
diff --git a/src/type_determiner.h b/src/type_determiner.h
index 692947e..34fcaf3 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -31,6 +31,7 @@
class CastExpression;
class ConstructorExpression;
class IdentifierExpression;
+class MemberAccessorExpression;
class Function;
class Variable;
@@ -67,10 +68,6 @@
/// @param stmt the statement to check
/// @returns true if the determination was successful
bool DetermineResultType(ast::Statement* stmt);
- /// Determines type information for a list of expressions
- /// @param exprs the expressions to check
- /// @returns true if the determination was successful
- bool DetermineResultType(const ast::ExpressionList& exprs);
/// Determines type information for an expression
/// @param expr the expression to check
/// @returns true if the determination was successful
@@ -83,6 +80,8 @@
bool DetermineCast(ast::CastExpression* expr);
bool DetermineConstructor(ast::ConstructorExpression* expr);
bool DetermineIdentifier(ast::IdentifierExpression* expr);
+ bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
+
Context& ctx_;
std::string error_;
ScopeStack<ast::Variable*> variable_stack_;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 7796ca7..486246b 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -1,3 +1,4 @@
+
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -32,14 +33,18 @@
#include "src/ast/if_statement.h"
#include "src/ast/int_literal.h"
#include "src/ast/loop_statement.h"
+#include "src/ast/member_accessor_expression.h"
#include "src/ast/regardless_statement.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/struct.h"
+#include "src/ast/struct_member.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
+#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/unless_statement.h"
@@ -512,45 +517,6 @@
EXPECT_TRUE(call.result_type()->IsF32());
}
-TEST_F(TypeDeterminerTest, Expr_Call_WithParams) {
- ast::type::F32Type f32;
- ast::type::I32Type i32;
-
- ast::VariableList params;
- params.push_back(
- std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &f32));
- params.push_back(
- std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32));
-
- auto func =
- std::make_unique<ast::Function>("my_func", std::move(params), &f32);
- ast::Module m;
- m.AddFunction(std::move(func));
-
- // Register the function
- EXPECT_TRUE(td()->Determine(&m));
-
- ast::ExpressionList call_params;
- call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
- std::make_unique<ast::FloatLiteral>(&f32, 2.5f)));
- auto a_ptr = call_params.back().get();
- call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
- std::make_unique<ast::IntLiteral>(&i32, 1)));
- auto b_ptr = call_params.back().get();
-
- ast::CallExpression call(
- std::make_unique<ast::IdentifierExpression>("my_func"),
- std::move(call_params));
- EXPECT_TRUE(td()->DetermineResultType(&call));
- ASSERT_NE(call.result_type(), nullptr);
- EXPECT_TRUE(call.result_type()->IsF32());
-
- ASSERT_NE(a_ptr->result_type(), nullptr);
- EXPECT_TRUE(a_ptr->result_type()->IsF32());
- ASSERT_NE(b_ptr->result_type(), nullptr);
- EXPECT_TRUE(b_ptr->result_type()->IsI32());
-}
-
TEST_F(TypeDeterminerTest, Expr_Cast) {
ast::type::F32Type f32;
ast::CastExpression cast(&f32,
@@ -651,5 +617,143 @@
EXPECT_TRUE(ident.result_type()->IsF32());
}
+TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
+ ast::type::I32Type i32;
+ ast::type::F32Type f32;
+
+ ast::StructMemberDecorationList decos;
+ ast::StructMemberList members;
+ members.push_back(std::make_unique<ast::StructMember>("first_member", &i32,
+ std::move(decos)));
+ members.push_back(std::make_unique<ast::StructMember>("second_member", &f32,
+ std::move(decos)));
+
+ auto strct = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
+ std::move(members));
+
+ ast::type::StructType st(std::move(strct));
+
+ auto var = std::make_unique<ast::Variable>("my_struct",
+ ast::StorageClass::kNone, &st);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ EXPECT_TRUE(td()->Determine(&m));
+
+ auto ident = std::make_unique<ast::IdentifierExpression>("my_struct");
+ auto mem_ident = std::make_unique<ast::IdentifierExpression>("second_member");
+
+ ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident));
+ EXPECT_TRUE(td()->DetermineResultType(&mem));
+ ASSERT_NE(mem.result_type(), nullptr);
+ EXPECT_TRUE(mem.result_type()->IsF32());
+}
+
+TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto var = std::make_unique<ast::Variable>("my_vec", ast::StorageClass::kNone,
+ &vec3);
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ EXPECT_TRUE(td()->Determine(&m));
+
+ auto ident = std::make_unique<ast::IdentifierExpression>("my_vec");
+ auto swizzle = std::make_unique<ast::IdentifierExpression>("xy");
+
+ ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
+ EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
+ ASSERT_NE(mem.result_type(), nullptr);
+ ASSERT_TRUE(mem.result_type()->IsVector());
+ EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
+ EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
+}
+
+TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
+ // struct b {
+ // vec4<f32> foo
+ // }
+ // struct A {
+ // vec3<struct b> mem
+ // }
+ // var c : A
+ // c.mem[0].foo.yx
+ // -> vec2<f32>
+ //
+ // MemberAccessor{
+ // MemberAccessor{
+ // ArrayAccessor{
+ // MemberAccessor{
+ // Identifier{c}
+ // Identifier{mem}
+ // }
+ // ScalarConstructor{0}
+ // }
+ // Identifier{foo}
+ // }
+ // Identifier{yx}
+ // }
+ //
+ ast::type::I32Type i32;
+ ast::type::F32Type f32;
+
+ ast::type::VectorType vec4(&f32, 4);
+
+ ast::StructMemberDecorationList decos;
+ ast::StructMemberList b_members;
+ b_members.push_back(
+ std::make_unique<ast::StructMember>("foo", &vec4, std::move(decos)));
+
+ auto strctB = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
+ std::move(b_members));
+ ast::type::StructType stB(std::move(strctB));
+
+ ast::type::VectorType vecB(&stB, 3);
+
+ ast::StructMemberList a_members;
+ a_members.push_back(
+ std::make_unique<ast::StructMember>("mem", &vecB, std::move(decos)));
+
+ auto strctA = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
+ std::move(a_members));
+
+ ast::type::StructType stA(std::move(strctA));
+
+ auto var =
+ std::make_unique<ast::Variable>("c", ast::StorageClass::kNone, &stA);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ EXPECT_TRUE(td()->Determine(&m));
+
+ auto ident = std::make_unique<ast::IdentifierExpression>("c");
+ auto mem_ident = std::make_unique<ast::IdentifierExpression>("mem");
+ auto foo_ident = std::make_unique<ast::IdentifierExpression>("foo");
+ auto idx = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 0));
+ auto swizzle = std::make_unique<ast::IdentifierExpression>("yx");
+
+ ast::MemberAccessorExpression mem(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::ArrayAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::move(ident), std::move(mem_ident)),
+ std::move(idx)),
+ std::move(foo_ident)),
+ std::move(swizzle));
+ EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
+ ASSERT_NE(mem.result_type(), nullptr);
+ ASSERT_TRUE(mem.result_type()->IsVector());
+ EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
+ EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
+}
+
} // namespace
} // namespace tint