semantic: Add Function::Parameters(), MemberAccessorExpression subtypes
Add semantic::Swizzle and semantic::StructMemberAccess, both deriving from MemberAccessorExpression
Add semantic::Function::Parameters() to list the semantic::Variable parameters for the function.
Change-Id: I8cc69f3738380c14f61d051ee2989be6194d148d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47220
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index c9110a4..06461e5 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -534,7 +534,9 @@
variable_stack_.push_scope();
for (auto* param : func->params()) {
- variable_stack_.set(param->symbol(), CreateVariableInfo(param));
+ auto* param_info = CreateVariableInfo(param);
+ variable_stack_.set(param->symbol(), param_info);
+ func_info->parameters.emplace_back(param_info);
if (!ApplyStorageClassUsageToType(param->declared_storage_class(),
param->declared_type(),
@@ -1171,12 +1173,14 @@
std::vector<uint32_t> swizzle;
if (auto* ty = data_type->As<type::Struct>()) {
- auto* strct = ty->impl();
+ auto* str = Structure(ty);
auto symbol = expr->member()->symbol();
- for (auto* member : strct->members()) {
- if (member->symbol() == symbol) {
- ret = member->type();
+ const semantic::StructMember* member = nullptr;
+ for (auto* m : str->members) {
+ if (m->Declaration()->symbol() == symbol) {
+ ret = m->Declaration()->type();
+ member = m;
break;
}
}
@@ -1192,6 +1196,9 @@
if (auto* ptr = res->As<type::Pointer>()) {
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
}
+
+ builder_->Sem().Add(expr, builder_->create<semantic::StructMemberAccess>(
+ expr, ret, current_statement_, member));
} else if (auto* vec = data_type->As<type::Vector>()) {
std::string str = builder_->Symbols().NameFor(expr->member()->symbol());
auto size = str.size();
@@ -1257,6 +1264,9 @@
ret = builder_->create<type::Vector>(vec->type(),
static_cast<uint32_t>(size));
}
+ builder_->Sem().Add(
+ expr, builder_->create<semantic::Swizzle>(expr, ret, current_statement_,
+ std::move(swizzle)));
} else {
diagnostics_.add_error(
"invalid use of member accessor on a non-vector/non-struct " +
@@ -1265,9 +1275,6 @@
return false;
}
- builder_->Sem().Add(expr,
- builder_->create<semantic::MemberAccessorExpression>(
- expr, ret, current_statement_, std::move(swizzle)));
SetType(expr, ret);
return true;
@@ -1682,7 +1689,8 @@
auto* info = it.second;
auto* sem_func = builder_->create<semantic::Function>(
- info->declaration, remap_vars(info->referenced_module_vars),
+ info->declaration, remap_vars(info->parameters),
+ remap_vars(info->referenced_module_vars),
remap_vars(info->local_referenced_module_vars), info->return_statements,
ancestor_entry_points[func->symbol()]);
func_info_to_sem_func.emplace(info, sem_func);
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 76c99c3..e7f1f76 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -106,6 +106,7 @@
~FunctionInfo();
ast::Function* const declaration;
+ std::vector<VariableInfo*> parameters;
UniqueVector<VariableInfo*> referenced_module_vars;
UniqueVector<VariableInfo*> local_referenced_module_vars;
std::vector<const ast::ReturnStatement*> return_statements;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 470adab..389fcec 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -736,6 +736,32 @@
EXPECT_FALSE(r()->Resolve());
}
+TEST_F(ResolverTest, Function_Parameters) {
+ auto* param_a = Param("a", ty.f32());
+ auto* param_b = Param("b", ty.i32());
+ auto* param_c = Param("c", ty.u32());
+
+ auto* func = Func("my_func",
+ ast::VariableList{
+ param_a,
+ param_b,
+ param_c,
+ },
+ ty.void_(), {});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+ EXPECT_EQ(func_sem->Parameters().size(), 3u);
+ EXPECT_EQ(func_sem->Parameters()[0]->Type(), ty.f32());
+ EXPECT_EQ(func_sem->Parameters()[1]->Type(), ty.i32());
+ EXPECT_EQ(func_sem->Parameters()[2]->Type(), ty.u32());
+ EXPECT_EQ(func_sem->Parameters()[0]->Declaration(), param_a);
+ EXPECT_EQ(func_sem->Parameters()[1]->Declaration(), param_b);
+ EXPECT_EQ(func_sem->Parameters()[2]->Declaration(), param_c);
+}
+
TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
auto* in_var = Global("in_var", ty.f32(), ast::StorageClass::kInput);
auto* out_var = Global("out_var", ty.f32(), ast::StorageClass::kOutput);
@@ -757,6 +783,7 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
+ EXPECT_EQ(func_sem->Parameters().size(), 0u);
const auto& vars = func_sem->ReferencedModuleVariables();
ASSERT_EQ(vars.size(), 5u);
@@ -794,6 +821,7 @@
auto* func2_sem = Sem().Get(func2);
ASSERT_NE(func2_sem, nullptr);
+ EXPECT_EQ(func2_sem->Parameters().size(), 0u);
const auto& vars = func2_sem->ReferencedModuleVariables();
ASSERT_EQ(vars.size(), 5u);
@@ -842,6 +870,7 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
+ EXPECT_EQ(func_sem->Parameters().size(), 0u);
EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
@@ -867,6 +896,14 @@
auto* ptr = TypeOf(mem)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
+ ASSERT_TRUE(Sem().Get(mem)->Is<semantic::StructMemberAccess>());
+ EXPECT_EQ(Sem()
+ .Get(mem)
+ ->As<semantic::StructMemberAccess>()
+ ->Member()
+ ->Declaration()
+ ->symbol(),
+ Symbols().Get("second_member"));
}
TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) {
@@ -889,6 +926,7 @@
auto* ptr = TypeOf(mem)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
+ ASSERT_TRUE(Sem().Get(mem)->Is<semantic::StructMemberAccess>());
}
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) {
@@ -903,7 +941,9 @@
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 4u);
- EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(0, 2, 1, 3));
+ ASSERT_TRUE(Sem().Get(mem)->Is<semantic::Swizzle>());
+ EXPECT_THAT(Sem().Get(mem)->As<semantic::Swizzle>()->Indices(),
+ ElementsAre(0, 2, 1, 3));
}
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
@@ -919,7 +959,9 @@
auto* ptr = TypeOf(mem)->As<type::Pointer>();
ASSERT_TRUE(ptr->type()->Is<type::F32>());
- EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(2));
+ ASSERT_TRUE(Sem().Get(mem)->Is<semantic::Swizzle>());
+ EXPECT_THAT(Sem().Get(mem)->As<semantic::Swizzle>()->Indices(),
+ ElementsAre(2));
}
TEST_F(ResolverTest, Expr_Accessor_MultiLevel) {
@@ -971,6 +1013,7 @@
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u);
+ ASSERT_TRUE(Sem().Get(mem)->Is<semantic::Swizzle>());
}
TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
@@ -1502,6 +1545,10 @@
ASSERT_NE(ep_1_sem, nullptr);
ASSERT_NE(ep_2_sem, nullptr);
+ EXPECT_EQ(func_b_sem->Parameters().size(), 0u);
+ EXPECT_EQ(func_a_sem->Parameters().size(), 0u);
+ EXPECT_EQ(func_c_sem->Parameters().size(), 0u);
+
const auto& b_eps = func_b_sem->AncestorEntryPoints();
ASSERT_EQ(2u, b_eps.size());
EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]);
diff --git a/src/semantic/function.h b/src/semantic/function.h
index efc6fbc..6301374 100644
--- a/src/semantic/function.h
+++ b/src/semantic/function.h
@@ -46,12 +46,14 @@
/// Constructor
/// @param declaration the ast::Function
+ /// @param parameters the parameters to the function
/// @param referenced_module_vars the referenced module variables
/// @param local_referenced_module_vars the locally referenced module
/// @param return_statements the function return statements
/// variables
/// @param ancestor_entry_points the ancestor entry points
Function(ast::Function* declaration,
+ std::vector<const Variable*> parameters,
std::vector<const Variable*> referenced_module_vars,
std::vector<const Variable*> local_referenced_module_vars,
std::vector<const ast::ReturnStatement*> return_statements,
@@ -63,6 +65,9 @@
/// @returns the ast::Function declaration
ast::Function* Declaration() const { return declaration_; }
+ /// @return the parameters to the function
+ const std::vector<const Variable*> Parameters() const { return parameters_; }
+
/// Note: If this function calls other functions, the return will also include
/// all of the referenced variables from the callees.
/// @returns the referenced module variables
@@ -147,6 +152,7 @@
bool multisampled) const;
ast::Function* const declaration_;
+ std::vector<const Variable*> const parameters_;
std::vector<const Variable*> const referenced_module_vars_;
std::vector<const Variable*> const local_referenced_module_vars_;
std::vector<const ast::ReturnStatement*> const return_statements_;
diff --git a/src/semantic/member_accessor_expression.h b/src/semantic/member_accessor_expression.h
index 71268b4..4865b77 100644
--- a/src/semantic/member_accessor_expression.h
+++ b/src/semantic/member_accessor_expression.h
@@ -20,8 +20,18 @@
#include "src/semantic/expression.h"
namespace tint {
+
+/// Forward declarations
+namespace ast {
+class MemberAccessorExpression;
+} // namespace ast
+
namespace semantic {
+/// Forward declarations
+class Struct;
+class StructMember;
+
/// MemberAccessorExpression holds the semantic information for a
/// ast::MemberAccessorExpression node.
class MemberAccessorExpression
@@ -31,24 +41,60 @@
/// @param declaration the AST node
/// @param type the resolved type of the expression
/// @param statement the statement that owns this expression
- /// @param swizzle if this member access is for a vector swizzle, the swizzle
- /// indices
- MemberAccessorExpression(ast::Expression* declaration,
+ MemberAccessorExpression(ast::MemberAccessorExpression* declaration,
type::Type* type,
- Statement* statement,
- std::vector<uint32_t> swizzle);
+ Statement* statement);
/// Destructor
~MemberAccessorExpression() override;
+};
- /// @return true if this member access is for a vector swizzle
- bool IsSwizzle() const { return !swizzle_.empty(); }
+/// StructMemberAccess holds the semantic information for a
+/// ast::MemberAccessorExpression node that represents an access to a structure
+/// member.
+class StructMemberAccess
+ : public Castable<StructMemberAccess, MemberAccessorExpression> {
+ public:
+ /// Constructor
+ /// @param declaration the AST node
+ /// @param type the resolved type of the expression
+ /// @param member the structure member
+ StructMemberAccess(ast::MemberAccessorExpression* declaration,
+ type::Type* type,
+ Statement* statement,
+ const StructMember* member);
- /// @return the swizzle indices, if this is a vector swizzle
- const std::vector<uint32_t>& Swizzle() const { return swizzle_; }
+ /// Destructor
+ ~StructMemberAccess() override;
+
+ /// @returns the structure member
+ StructMember const* Member() const { return member_; }
private:
- std::vector<uint32_t> const swizzle_;
+ StructMember const* const member_;
+};
+
+/// Swizzle holds the semantic information for a ast::MemberAccessorExpression
+/// node that represents a vector swizzle.
+class Swizzle : public Castable<Swizzle, MemberAccessorExpression> {
+ public:
+ /// Constructor
+ /// @param declaration the AST node
+ /// @param type the resolved type of the expression
+ /// @param indices the swizzle indices
+ Swizzle(ast::MemberAccessorExpression* declaration,
+ type::Type* type,
+ Statement* statement,
+ std::vector<uint32_t> indices);
+
+ /// Destructor
+ ~Swizzle() override;
+
+ /// @return the swizzle indices, if this is a vector swizzle
+ const std::vector<uint32_t>& Indices() const { return indices_; }
+
+ private:
+ std::vector<uint32_t> const indices_;
};
} // namespace semantic
diff --git a/src/semantic/sem_function.cc b/src/semantic/sem_function.cc
index dab7f1c..97e94af 100644
--- a/src/semantic/sem_function.cc
+++ b/src/semantic/sem_function.cc
@@ -41,12 +41,14 @@
} // namespace
Function::Function(ast::Function* declaration,
+ std::vector<const Variable*> parameters,
std::vector<const Variable*> referenced_module_vars,
std::vector<const Variable*> local_referenced_module_vars,
std::vector<const ast::ReturnStatement*> return_statements,
std::vector<Symbol> ancestor_entry_points)
: Base(declaration->return_type(), GetParameters(declaration)),
declaration_(declaration),
+ parameters_(std::move(parameters)),
referenced_module_vars_(std::move(referenced_module_vars)),
local_referenced_module_vars_(std::move(local_referenced_module_vars)),
return_statements_(std::move(return_statements)),
diff --git a/src/semantic/sem_member_accessor_expression.cc b/src/semantic/sem_member_accessor_expression.cc
index 00488cc..d0d4968 100644
--- a/src/semantic/sem_member_accessor_expression.cc
+++ b/src/semantic/sem_member_accessor_expression.cc
@@ -12,21 +12,40 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "src/ast/member_accessor_expression.h"
#include "src/semantic/member_accessor_expression.h"
TINT_INSTANTIATE_TYPEINFO(tint::semantic::MemberAccessorExpression);
+TINT_INSTANTIATE_TYPEINFO(tint::semantic::StructMemberAccess);
+TINT_INSTANTIATE_TYPEINFO(tint::semantic::Swizzle);
namespace tint {
namespace semantic {
MemberAccessorExpression::MemberAccessorExpression(
- ast::Expression* declaration,
+ ast::MemberAccessorExpression* declaration,
type::Type* type,
- Statement* statement,
- std::vector<uint32_t> swizzle)
- : Base(declaration, type, statement), swizzle_(std::move(swizzle)) {}
+ Statement* statement)
+ : Base(declaration, type, statement) {}
MemberAccessorExpression::~MemberAccessorExpression() = default;
+StructMemberAccess::StructMemberAccess(
+ ast::MemberAccessorExpression* declaration,
+ type::Type* type,
+ Statement* statement,
+ const StructMember* member)
+ : Base(declaration, type, statement), member_(member) {}
+
+StructMemberAccess::~StructMemberAccess() = default;
+
+Swizzle::Swizzle(ast::MemberAccessorExpression* declaration,
+ type::Type* type,
+ Statement* statement,
+ std::vector<uint32_t> indices)
+ : Base(declaration, type, statement), indices_(std::move(indices)) {}
+
+Swizzle::~Swizzle() = default;
+
} // namespace semantic
} // namespace tint
diff --git a/src/transform/decompose_storage_access.cc b/src/transform/decompose_storage_access.cc
index 7b1a639..dc38ba4 100644
--- a/src/transform/decompose_storage_access.cc
+++ b/src/transform/decompose_storage_access.cc
@@ -644,27 +644,12 @@
if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
// X.Y
auto* accessor_sem = sem.Get(accessor);
- auto swizzle = accessor_sem->Swizzle();
- switch (swizzle.size()) {
- case 0: {
- if (auto access = state.TakeAccess(accessor->structure())) {
- auto* str_ty = access.type->As<type::Struct>();
- auto* member =
- sem.Get(str_ty)->FindMember(accessor->member()->symbol());
- auto offset = member->Offset();
- state.AddAccesss(
- accessor, {
- access.var,
- Add(std::move(access.offset), std::move(offset)),
- member->Declaration()->type()->UnwrapAll(),
- });
- }
- break;
- }
- case 1: {
+ if (auto* swizzle = accessor_sem->As<semantic::Swizzle>()) {
+ if (swizzle->Indices().size() == 1) {
if (auto access = state.TakeAccess(accessor->structure())) {
auto* vec_ty = access.type->As<type::Vector>();
- auto offset = Mul(ScalarSize(vec_ty->type()), swizzle[0]);
+ auto offset =
+ Mul(ScalarSize(vec_ty->type()), swizzle->Indices()[0]);
state.AddAccesss(
accessor, {
access.var,
@@ -672,7 +657,19 @@
vec_ty->type()->UnwrapAll(),
});
}
- break;
+ }
+ } else {
+ if (auto access = state.TakeAccess(accessor->structure())) {
+ auto* str_ty = access.type->As<type::Struct>();
+ auto* member =
+ sem.Get(str_ty)->FindMember(accessor->member()->symbol());
+ auto offset = member->Offset();
+ state.AddAccesss(accessor,
+ {
+ access.var,
+ Add(std::move(access.offset), std::move(offset)),
+ member->Declaration()->type()->UnwrapAll(),
+ });
}
}
continue;
diff --git a/src/transform/renamer.cc b/src/transform/renamer.cc
index 9277136..a22c785 100644
--- a/src/transform/renamer.cc
+++ b/src/transform/renamer.cc
@@ -53,7 +53,7 @@
<< "MemberAccessorExpression has no semantic info";
continue;
}
- if (sem->IsSwizzle()) {
+ if (sem->Is<semantic::Swizzle>()) {
preserve.emplace(member->member());
}
} else if (auto* call = node->As<ast::CallExpression>()) {
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 0fb6a12..9c1b179 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -2246,7 +2246,7 @@
out << ".";
// Swizzles output the name directly
- if (builder_.Sem().Get(expr)->IsSwizzle()) {
+ if (builder_.Sem().Get(expr)->Is<semantic::Swizzle>()) {
out << builder_.Symbols().NameFor(expr->member()->symbol());
} else if (!EmitExpression(pre, out, expr->member())) {
return false;
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 98a2a85..5788eb9 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -1737,7 +1737,7 @@
out_ << ".";
// Swizzles get written out directly
- if (program_->Sem().Get(expr)->IsSwizzle()) {
+ if (program_->Sem().Get(expr)->Is<semantic::Swizzle>()) {
out_ << program_->Symbols().NameFor(expr->member()->symbol());
} else if (!EmitExpression(expr->member())) {
return false;