semantic: Add Function::Parameters(), MemberAccessorExpression subtypes

Add semantic::Swizzle and semantic::StructMemberAccess, both deriving from MemberAccessorExpression

Add semantic::Function::Parameters() to list the semantic::Variable parameters for the function.

Change-Id: I8cc69f3738380c14f61d051ee2989be6194d148d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47220
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index c9110a4..06461e5 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -534,7 +534,9 @@
 
   variable_stack_.push_scope();
   for (auto* param : func->params()) {
-    variable_stack_.set(param->symbol(), CreateVariableInfo(param));
+    auto* param_info = CreateVariableInfo(param);
+    variable_stack_.set(param->symbol(), param_info);
+    func_info->parameters.emplace_back(param_info);
 
     if (!ApplyStorageClassUsageToType(param->declared_storage_class(),
                                       param->declared_type(),
@@ -1171,12 +1173,14 @@
   std::vector<uint32_t> swizzle;
 
   if (auto* ty = data_type->As<type::Struct>()) {
-    auto* strct = ty->impl();
+    auto* str = Structure(ty);
     auto symbol = expr->member()->symbol();
 
-    for (auto* member : strct->members()) {
-      if (member->symbol() == symbol) {
-        ret = member->type();
+    const semantic::StructMember* member = nullptr;
+    for (auto* m : str->members) {
+      if (m->Declaration()->symbol() == symbol) {
+        ret = m->Declaration()->type();
+        member = m;
         break;
       }
     }
@@ -1192,6 +1196,9 @@
     if (auto* ptr = res->As<type::Pointer>()) {
       ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
     }
+
+    builder_->Sem().Add(expr, builder_->create<semantic::StructMemberAccess>(
+                                  expr, ret, current_statement_, member));
   } else if (auto* vec = data_type->As<type::Vector>()) {
     std::string str = builder_->Symbols().NameFor(expr->member()->symbol());
     auto size = str.size();
@@ -1257,6 +1264,9 @@
       ret = builder_->create<type::Vector>(vec->type(),
                                            static_cast<uint32_t>(size));
     }
+    builder_->Sem().Add(
+        expr, builder_->create<semantic::Swizzle>(expr, ret, current_statement_,
+                                                  std::move(swizzle)));
   } else {
     diagnostics_.add_error(
         "invalid use of member accessor on a non-vector/non-struct " +
@@ -1265,9 +1275,6 @@
     return false;
   }
 
-  builder_->Sem().Add(expr,
-                      builder_->create<semantic::MemberAccessorExpression>(
-                          expr, ret, current_statement_, std::move(swizzle)));
   SetType(expr, ret);
 
   return true;
@@ -1682,7 +1689,8 @@
     auto* info = it.second;
 
     auto* sem_func = builder_->create<semantic::Function>(
-        info->declaration, remap_vars(info->referenced_module_vars),
+        info->declaration, remap_vars(info->parameters),
+        remap_vars(info->referenced_module_vars),
         remap_vars(info->local_referenced_module_vars), info->return_statements,
         ancestor_entry_points[func->symbol()]);
     func_info_to_sem_func.emplace(info, sem_func);
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 76c99c3..e7f1f76 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -106,6 +106,7 @@
     ~FunctionInfo();
 
     ast::Function* const declaration;
+    std::vector<VariableInfo*> parameters;
     UniqueVector<VariableInfo*> referenced_module_vars;
     UniqueVector<VariableInfo*> local_referenced_module_vars;
     std::vector<const ast::ReturnStatement*> return_statements;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 470adab..389fcec 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -736,6 +736,32 @@
   EXPECT_FALSE(r()->Resolve());
 }
 
+TEST_F(ResolverTest, Function_Parameters) {
+  auto* param_a = Param("a", ty.f32());
+  auto* param_b = Param("b", ty.i32());
+  auto* param_c = Param("c", ty.u32());
+
+  auto* func = Func("my_func",
+                    ast::VariableList{
+                        param_a,
+                        param_b,
+                        param_c,
+                    },
+                    ty.void_(), {});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* func_sem = Sem().Get(func);
+  ASSERT_NE(func_sem, nullptr);
+  EXPECT_EQ(func_sem->Parameters().size(), 3u);
+  EXPECT_EQ(func_sem->Parameters()[0]->Type(), ty.f32());
+  EXPECT_EQ(func_sem->Parameters()[1]->Type(), ty.i32());
+  EXPECT_EQ(func_sem->Parameters()[2]->Type(), ty.u32());
+  EXPECT_EQ(func_sem->Parameters()[0]->Declaration(), param_a);
+  EXPECT_EQ(func_sem->Parameters()[1]->Declaration(), param_b);
+  EXPECT_EQ(func_sem->Parameters()[2]->Declaration(), param_c);
+}
+
 TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
   auto* in_var = Global("in_var", ty.f32(), ast::StorageClass::kInput);
   auto* out_var = Global("out_var", ty.f32(), ast::StorageClass::kOutput);
@@ -757,6 +783,7 @@
 
   auto* func_sem = Sem().Get(func);
   ASSERT_NE(func_sem, nullptr);
+  EXPECT_EQ(func_sem->Parameters().size(), 0u);
 
   const auto& vars = func_sem->ReferencedModuleVariables();
   ASSERT_EQ(vars.size(), 5u);
@@ -794,6 +821,7 @@
 
   auto* func2_sem = Sem().Get(func2);
   ASSERT_NE(func2_sem, nullptr);
+  EXPECT_EQ(func2_sem->Parameters().size(), 0u);
 
   const auto& vars = func2_sem->ReferencedModuleVariables();
   ASSERT_EQ(vars.size(), 5u);
@@ -842,6 +870,7 @@
 
   auto* func_sem = Sem().Get(func);
   ASSERT_NE(func_sem, nullptr);
+  EXPECT_EQ(func_sem->Parameters().size(), 0u);
 
   EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
   EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
@@ -867,6 +896,14 @@
 
   auto* ptr = TypeOf(mem)->As<type::Pointer>();
   EXPECT_TRUE(ptr->type()->Is<type::F32>());
+  ASSERT_TRUE(Sem().Get(mem)->Is<semantic::StructMemberAccess>());
+  EXPECT_EQ(Sem()
+                .Get(mem)
+                ->As<semantic::StructMemberAccess>()
+                ->Member()
+                ->Declaration()
+                ->symbol(),
+            Symbols().Get("second_member"));
 }
 
 TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) {
@@ -889,6 +926,7 @@
 
   auto* ptr = TypeOf(mem)->As<type::Pointer>();
   EXPECT_TRUE(ptr->type()->Is<type::F32>());
+  ASSERT_TRUE(Sem().Get(mem)->Is<semantic::StructMemberAccess>());
 }
 
 TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) {
@@ -903,7 +941,9 @@
   ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
   EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
   EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 4u);
-  EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(0, 2, 1, 3));
+  ASSERT_TRUE(Sem().Get(mem)->Is<semantic::Swizzle>());
+  EXPECT_THAT(Sem().Get(mem)->As<semantic::Swizzle>()->Indices(),
+              ElementsAre(0, 2, 1, 3));
 }
 
 TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
@@ -919,7 +959,9 @@
 
   auto* ptr = TypeOf(mem)->As<type::Pointer>();
   ASSERT_TRUE(ptr->type()->Is<type::F32>());
-  EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(2));
+  ASSERT_TRUE(Sem().Get(mem)->Is<semantic::Swizzle>());
+  EXPECT_THAT(Sem().Get(mem)->As<semantic::Swizzle>()->Indices(),
+              ElementsAre(2));
 }
 
 TEST_F(ResolverTest, Expr_Accessor_MultiLevel) {
@@ -971,6 +1013,7 @@
   ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
   EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
   EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u);
+  ASSERT_TRUE(Sem().Get(mem)->Is<semantic::Swizzle>());
 }
 
 TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
@@ -1502,6 +1545,10 @@
   ASSERT_NE(ep_1_sem, nullptr);
   ASSERT_NE(ep_2_sem, nullptr);
 
+  EXPECT_EQ(func_b_sem->Parameters().size(), 0u);
+  EXPECT_EQ(func_a_sem->Parameters().size(), 0u);
+  EXPECT_EQ(func_c_sem->Parameters().size(), 0u);
+
   const auto& b_eps = func_b_sem->AncestorEntryPoints();
   ASSERT_EQ(2u, b_eps.size());
   EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]);
diff --git a/src/semantic/function.h b/src/semantic/function.h
index efc6fbc..6301374 100644
--- a/src/semantic/function.h
+++ b/src/semantic/function.h
@@ -46,12 +46,14 @@
 
   /// Constructor
   /// @param declaration the ast::Function
+  /// @param parameters the parameters to the function
   /// @param referenced_module_vars the referenced module variables
   /// @param local_referenced_module_vars the locally referenced module
   /// @param return_statements the function return statements
   /// variables
   /// @param ancestor_entry_points the ancestor entry points
   Function(ast::Function* declaration,
+           std::vector<const Variable*> parameters,
            std::vector<const Variable*> referenced_module_vars,
            std::vector<const Variable*> local_referenced_module_vars,
            std::vector<const ast::ReturnStatement*> return_statements,
@@ -63,6 +65,9 @@
   /// @returns the ast::Function declaration
   ast::Function* Declaration() const { return declaration_; }
 
+  /// @return the parameters to the function
+  const std::vector<const Variable*> Parameters() const { return parameters_; }
+
   /// Note: If this function calls other functions, the return will also include
   /// all of the referenced variables from the callees.
   /// @returns the referenced module variables
@@ -147,6 +152,7 @@
       bool multisampled) const;
 
   ast::Function* const declaration_;
+  std::vector<const Variable*> const parameters_;
   std::vector<const Variable*> const referenced_module_vars_;
   std::vector<const Variable*> const local_referenced_module_vars_;
   std::vector<const ast::ReturnStatement*> const return_statements_;
diff --git a/src/semantic/member_accessor_expression.h b/src/semantic/member_accessor_expression.h
index 71268b4..4865b77 100644
--- a/src/semantic/member_accessor_expression.h
+++ b/src/semantic/member_accessor_expression.h
@@ -20,8 +20,18 @@
 #include "src/semantic/expression.h"
 
 namespace tint {
+
+/// Forward declarations
+namespace ast {
+class MemberAccessorExpression;
+}  // namespace ast
+
 namespace semantic {
 
+/// Forward declarations
+class Struct;
+class StructMember;
+
 /// MemberAccessorExpression holds the semantic information for a
 /// ast::MemberAccessorExpression node.
 class MemberAccessorExpression
@@ -31,24 +41,60 @@
   /// @param declaration the AST node
   /// @param type the resolved type of the expression
   /// @param statement the statement that owns this expression
-  /// @param swizzle if this member access is for a vector swizzle, the swizzle
-  /// indices
-  MemberAccessorExpression(ast::Expression* declaration,
+  MemberAccessorExpression(ast::MemberAccessorExpression* declaration,
                            type::Type* type,
-                           Statement* statement,
-                           std::vector<uint32_t> swizzle);
+                           Statement* statement);
 
   /// Destructor
   ~MemberAccessorExpression() override;
+};
 
-  /// @return true if this member access is for a vector swizzle
-  bool IsSwizzle() const { return !swizzle_.empty(); }
+/// StructMemberAccess holds the semantic information for a
+/// ast::MemberAccessorExpression node that represents an access to a structure
+/// member.
+class StructMemberAccess
+    : public Castable<StructMemberAccess, MemberAccessorExpression> {
+ public:
+  /// Constructor
+  /// @param declaration the AST node
+  /// @param type the resolved type of the expression
+  /// @param member the structure member
+  StructMemberAccess(ast::MemberAccessorExpression* declaration,
+                     type::Type* type,
+                     Statement* statement,
+                     const StructMember* member);
 
-  /// @return the swizzle indices, if this is a vector swizzle
-  const std::vector<uint32_t>& Swizzle() const { return swizzle_; }
+  /// Destructor
+  ~StructMemberAccess() override;
+
+  /// @returns the structure member
+  StructMember const* Member() const { return member_; }
 
  private:
-  std::vector<uint32_t> const swizzle_;
+  StructMember const* const member_;
+};
+
+/// Swizzle holds the semantic information for a ast::MemberAccessorExpression
+/// node that represents a vector swizzle.
+class Swizzle : public Castable<Swizzle, MemberAccessorExpression> {
+ public:
+  /// Constructor
+  /// @param declaration the AST node
+  /// @param type the resolved type of the expression
+  /// @param indices the swizzle indices
+  Swizzle(ast::MemberAccessorExpression* declaration,
+          type::Type* type,
+          Statement* statement,
+          std::vector<uint32_t> indices);
+
+  /// Destructor
+  ~Swizzle() override;
+
+  /// @return the swizzle indices, if this is a vector swizzle
+  const std::vector<uint32_t>& Indices() const { return indices_; }
+
+ private:
+  std::vector<uint32_t> const indices_;
 };
 
 }  // namespace semantic
diff --git a/src/semantic/sem_function.cc b/src/semantic/sem_function.cc
index dab7f1c..97e94af 100644
--- a/src/semantic/sem_function.cc
+++ b/src/semantic/sem_function.cc
@@ -41,12 +41,14 @@
 }  // namespace
 
 Function::Function(ast::Function* declaration,
+                   std::vector<const Variable*> parameters,
                    std::vector<const Variable*> referenced_module_vars,
                    std::vector<const Variable*> local_referenced_module_vars,
                    std::vector<const ast::ReturnStatement*> return_statements,
                    std::vector<Symbol> ancestor_entry_points)
     : Base(declaration->return_type(), GetParameters(declaration)),
       declaration_(declaration),
+      parameters_(std::move(parameters)),
       referenced_module_vars_(std::move(referenced_module_vars)),
       local_referenced_module_vars_(std::move(local_referenced_module_vars)),
       return_statements_(std::move(return_statements)),
diff --git a/src/semantic/sem_member_accessor_expression.cc b/src/semantic/sem_member_accessor_expression.cc
index 00488cc..d0d4968 100644
--- a/src/semantic/sem_member_accessor_expression.cc
+++ b/src/semantic/sem_member_accessor_expression.cc
@@ -12,21 +12,40 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "src/ast/member_accessor_expression.h"
 #include "src/semantic/member_accessor_expression.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::semantic::MemberAccessorExpression);
+TINT_INSTANTIATE_TYPEINFO(tint::semantic::StructMemberAccess);
+TINT_INSTANTIATE_TYPEINFO(tint::semantic::Swizzle);
 
 namespace tint {
 namespace semantic {
 
 MemberAccessorExpression::MemberAccessorExpression(
-    ast::Expression* declaration,
+    ast::MemberAccessorExpression* declaration,
     type::Type* type,
-    Statement* statement,
-    std::vector<uint32_t> swizzle)
-    : Base(declaration, type, statement), swizzle_(std::move(swizzle)) {}
+    Statement* statement)
+    : Base(declaration, type, statement) {}
 
 MemberAccessorExpression::~MemberAccessorExpression() = default;
 
+StructMemberAccess::StructMemberAccess(
+    ast::MemberAccessorExpression* declaration,
+    type::Type* type,
+    Statement* statement,
+    const StructMember* member)
+    : Base(declaration, type, statement), member_(member) {}
+
+StructMemberAccess::~StructMemberAccess() = default;
+
+Swizzle::Swizzle(ast::MemberAccessorExpression* declaration,
+                 type::Type* type,
+                 Statement* statement,
+                 std::vector<uint32_t> indices)
+    : Base(declaration, type, statement), indices_(std::move(indices)) {}
+
+Swizzle::~Swizzle() = default;
+
 }  // namespace semantic
 }  // namespace tint
diff --git a/src/transform/decompose_storage_access.cc b/src/transform/decompose_storage_access.cc
index 7b1a639..dc38ba4 100644
--- a/src/transform/decompose_storage_access.cc
+++ b/src/transform/decompose_storage_access.cc
@@ -644,27 +644,12 @@
     if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
       // X.Y
       auto* accessor_sem = sem.Get(accessor);
-      auto swizzle = accessor_sem->Swizzle();
-      switch (swizzle.size()) {
-        case 0: {
-          if (auto access = state.TakeAccess(accessor->structure())) {
-            auto* str_ty = access.type->As<type::Struct>();
-            auto* member =
-                sem.Get(str_ty)->FindMember(accessor->member()->symbol());
-            auto offset = member->Offset();
-            state.AddAccesss(
-                accessor, {
-                              access.var,
-                              Add(std::move(access.offset), std::move(offset)),
-                              member->Declaration()->type()->UnwrapAll(),
-                          });
-          }
-          break;
-        }
-        case 1: {
+      if (auto* swizzle = accessor_sem->As<semantic::Swizzle>()) {
+        if (swizzle->Indices().size() == 1) {
           if (auto access = state.TakeAccess(accessor->structure())) {
             auto* vec_ty = access.type->As<type::Vector>();
-            auto offset = Mul(ScalarSize(vec_ty->type()), swizzle[0]);
+            auto offset =
+                Mul(ScalarSize(vec_ty->type()), swizzle->Indices()[0]);
             state.AddAccesss(
                 accessor, {
                               access.var,
@@ -672,7 +657,19 @@
                               vec_ty->type()->UnwrapAll(),
                           });
           }
-          break;
+        }
+      } else {
+        if (auto access = state.TakeAccess(accessor->structure())) {
+          auto* str_ty = access.type->As<type::Struct>();
+          auto* member =
+              sem.Get(str_ty)->FindMember(accessor->member()->symbol());
+          auto offset = member->Offset();
+          state.AddAccesss(accessor,
+                           {
+                               access.var,
+                               Add(std::move(access.offset), std::move(offset)),
+                               member->Declaration()->type()->UnwrapAll(),
+                           });
         }
       }
       continue;
diff --git a/src/transform/renamer.cc b/src/transform/renamer.cc
index 9277136..a22c785 100644
--- a/src/transform/renamer.cc
+++ b/src/transform/renamer.cc
@@ -53,7 +53,7 @@
             << "MemberAccessorExpression has no semantic info";
         continue;
       }
-      if (sem->IsSwizzle()) {
+      if (sem->Is<semantic::Swizzle>()) {
         preserve.emplace(member->member());
       }
     } else if (auto* call = node->As<ast::CallExpression>()) {
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 0fb6a12..9c1b179 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -2246,7 +2246,7 @@
   out << ".";
 
   // Swizzles output the name directly
-  if (builder_.Sem().Get(expr)->IsSwizzle()) {
+  if (builder_.Sem().Get(expr)->Is<semantic::Swizzle>()) {
     out << builder_.Symbols().NameFor(expr->member()->symbol());
   } else if (!EmitExpression(pre, out, expr->member())) {
     return false;
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 98a2a85..5788eb9 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -1737,7 +1737,7 @@
   out_ << ".";
 
   // Swizzles get written out directly
-  if (program_->Sem().Get(expr)->IsSwizzle()) {
+  if (program_->Sem().Get(expr)->Is<semantic::Swizzle>()) {
     out_ << program_->Symbols().NameFor(expr->member()->symbol());
   } else if (!EmitExpression(expr->member())) {
     return false;