Add type determination for member accessor.

This Cl adds the member accessor type determination for both structures
and vector swizzles.

Bug: tint:5
Change-Id: I1172db29d8cbed2d9e0ae228ebc3a818d4930b7f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18846
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/relational_expression.h b/src/ast/relational_expression.h
index 04b9064..f5bd944 100644
--- a/src/ast/relational_expression.h
+++ b/src/ast/relational_expression.h
@@ -48,7 +48,7 @@
   kModulo,
 };
 
-/// A Relational Expression
+/// An xor expression
 class RelationalExpression : public Expression {
  public:
   /// Constructor
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index f1da09c..9a81e19 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -28,12 +28,14 @@
 #include "src/ast/identifier_expression.h"
 #include "src/ast/if_statement.h"
 #include "src/ast/loop_statement.h"
+#include "src/ast/member_accessor_expression.h"
 #include "src/ast/regardless_statement.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/switch_statement.h"
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/matrix_type.h"
+#include "src/ast/type/struct_type.h"
 #include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/ast/unless_statement.h"
@@ -41,7 +43,10 @@
 
 namespace tint {
 
-TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {}
+TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {
+  // TODO(dsinclair): Temporary usage to avoid compiler warning
+  static_cast<void>(ctx_.type_mgr());
+}
 
 TypeDeterminer::~TypeDeterminer() = default;
 
@@ -174,15 +179,6 @@
   return false;
 }
 
-bool TypeDeterminer::DetermineResultType(const ast::ExpressionList& exprs) {
-  for (const auto& expr : exprs) {
-    if (!DetermineResultType(expr.get())) {
-      return false;
-    }
-  }
-  return true;
-}
-
 bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
   // This is blindly called above, so in some cases the expression won't exist.
   if (!expr) {
@@ -207,6 +203,9 @@
   if (expr->IsIdentifier()) {
     return DetermineIdentifier(expr->AsIdentifier());
   }
+  if (expr->IsMemberAccessor()) {
+    return DetermineMemberAccessor(expr->AsMemberAccessor());
+  }
 
   error_ = "unknown expression for type determination";
   return false;
@@ -242,9 +241,6 @@
   if (!DetermineResultType(expr->func())) {
     return false;
   }
-  if (!DetermineResultType(expr->params())) {
-    return false;
-  }
   expr->set_result_type(expr->func()->result_type());
   return true;
 }
@@ -283,7 +279,45 @@
     return true;
   }
 
-  error_ = "unknown identifier for type determination";
+  return true;
+}
+
+bool TypeDeterminer::DetermineMemberAccessor(
+    ast::MemberAccessorExpression* expr) {
+  if (!DetermineResultType(expr->structure())) {
+    return false;
+  }
+
+  auto data_type = expr->structure()->result_type();
+  if (data_type->IsStruct()) {
+    auto strct = data_type->AsStruct()->impl();
+    auto name = expr->member()->name()[0];
+
+    for (const auto& member : strct->members()) {
+      if (member->name() != name) {
+        continue;
+      }
+
+      expr->set_result_type(member->type());
+      return true;
+    }
+
+    error_ = "struct member not found";
+    return false;
+  }
+  if (data_type->IsVector()) {
+    auto vec = data_type->AsVector();
+
+    // The vector will have a number of components equal to the length of the
+    // swizzle. This assumes the validator will check that the swizzle
+    // is correct.
+    expr->set_result_type(
+        ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+            vec->type(), expr->member()->name()[0].size())));
+    return true;
+  }
+
+  error_ = "invalid type in member accessor";
   return false;
 }
 
diff --git a/src/type_determiner.h b/src/type_determiner.h
index 692947e..34fcaf3 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -31,6 +31,7 @@
 class CastExpression;
 class ConstructorExpression;
 class IdentifierExpression;
+class MemberAccessorExpression;
 class Function;
 class Variable;
 
@@ -67,10 +68,6 @@
   /// @param stmt the statement to check
   /// @returns true if the determination was successful
   bool DetermineResultType(ast::Statement* stmt);
-  /// Determines type information for a list of expressions
-  /// @param exprs the expressions to check
-  /// @returns true if the determination was successful
-  bool DetermineResultType(const ast::ExpressionList& exprs);
   /// Determines type information for an expression
   /// @param expr the expression to check
   /// @returns true if the determination was successful
@@ -83,6 +80,8 @@
   bool DetermineCast(ast::CastExpression* expr);
   bool DetermineConstructor(ast::ConstructorExpression* expr);
   bool DetermineIdentifier(ast::IdentifierExpression* expr);
+  bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
+
   Context& ctx_;
   std::string error_;
   ScopeStack<ast::Variable*> variable_stack_;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 7796ca7..486246b 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -1,3 +1,4 @@
+
 // Copyright 2020 The Tint Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -32,14 +33,18 @@
 #include "src/ast/if_statement.h"
 #include "src/ast/int_literal.h"
 #include "src/ast/loop_statement.h"
+#include "src/ast/member_accessor_expression.h"
 #include "src/ast/regardless_statement.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/struct.h"
+#include "src/ast/struct_member.h"
 #include "src/ast/switch_statement.h"
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/matrix_type.h"
+#include "src/ast/type/struct_type.h"
 #include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/ast/unless_statement.h"
@@ -512,45 +517,6 @@
   EXPECT_TRUE(call.result_type()->IsF32());
 }
 
-TEST_F(TypeDeterminerTest, Expr_Call_WithParams) {
-  ast::type::F32Type f32;
-  ast::type::I32Type i32;
-
-  ast::VariableList params;
-  params.push_back(
-      std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &f32));
-  params.push_back(
-      std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32));
-
-  auto func =
-      std::make_unique<ast::Function>("my_func", std::move(params), &f32);
-  ast::Module m;
-  m.AddFunction(std::move(func));
-
-  // Register the function
-  EXPECT_TRUE(td()->Determine(&m));
-
-  ast::ExpressionList call_params;
-  call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
-      std::make_unique<ast::FloatLiteral>(&f32, 2.5f)));
-  auto a_ptr = call_params.back().get();
-  call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
-      std::make_unique<ast::IntLiteral>(&i32, 1)));
-  auto b_ptr = call_params.back().get();
-
-  ast::CallExpression call(
-      std::make_unique<ast::IdentifierExpression>("my_func"),
-      std::move(call_params));
-  EXPECT_TRUE(td()->DetermineResultType(&call));
-  ASSERT_NE(call.result_type(), nullptr);
-  EXPECT_TRUE(call.result_type()->IsF32());
-
-  ASSERT_NE(a_ptr->result_type(), nullptr);
-  EXPECT_TRUE(a_ptr->result_type()->IsF32());
-  ASSERT_NE(b_ptr->result_type(), nullptr);
-  EXPECT_TRUE(b_ptr->result_type()->IsI32());
-}
-
 TEST_F(TypeDeterminerTest, Expr_Cast) {
   ast::type::F32Type f32;
   ast::CastExpression cast(&f32,
@@ -651,5 +617,143 @@
   EXPECT_TRUE(ident.result_type()->IsF32());
 }
 
+TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
+  ast::type::I32Type i32;
+  ast::type::F32Type f32;
+
+  ast::StructMemberDecorationList decos;
+  ast::StructMemberList members;
+  members.push_back(std::make_unique<ast::StructMember>("first_member", &i32,
+                                                        std::move(decos)));
+  members.push_back(std::make_unique<ast::StructMember>("second_member", &f32,
+                                                        std::move(decos)));
+
+  auto strct = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
+                                             std::move(members));
+
+  ast::type::StructType st(std::move(strct));
+
+  auto var = std::make_unique<ast::Variable>("my_struct",
+                                             ast::StorageClass::kNone, &st);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  EXPECT_TRUE(td()->Determine(&m));
+
+  auto ident = std::make_unique<ast::IdentifierExpression>("my_struct");
+  auto mem_ident = std::make_unique<ast::IdentifierExpression>("second_member");
+
+  ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident));
+  EXPECT_TRUE(td()->DetermineResultType(&mem));
+  ASSERT_NE(mem.result_type(), nullptr);
+  EXPECT_TRUE(mem.result_type()->IsF32());
+}
+
+TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  auto var = std::make_unique<ast::Variable>("my_vec", ast::StorageClass::kNone,
+                                             &vec3);
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  EXPECT_TRUE(td()->Determine(&m));
+
+  auto ident = std::make_unique<ast::IdentifierExpression>("my_vec");
+  auto swizzle = std::make_unique<ast::IdentifierExpression>("xy");
+
+  ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
+  EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
+  ASSERT_NE(mem.result_type(), nullptr);
+  ASSERT_TRUE(mem.result_type()->IsVector());
+  EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
+  EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
+}
+
+TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
+  // struct b {
+  //   vec4<f32> foo
+  // }
+  // struct A {
+  //   vec3<struct b> mem
+  // }
+  // var c : A
+  // c.mem[0].foo.yx
+  //   -> vec2<f32>
+  //
+  // MemberAccessor{
+  //   MemberAccessor{
+  //     ArrayAccessor{
+  //       MemberAccessor{
+  //         Identifier{c}
+  //         Identifier{mem}
+  //       }
+  //       ScalarConstructor{0}
+  //     }
+  //     Identifier{foo}
+  //   }
+  //   Identifier{yx}
+  // }
+  //
+  ast::type::I32Type i32;
+  ast::type::F32Type f32;
+
+  ast::type::VectorType vec4(&f32, 4);
+
+  ast::StructMemberDecorationList decos;
+  ast::StructMemberList b_members;
+  b_members.push_back(
+      std::make_unique<ast::StructMember>("foo", &vec4, std::move(decos)));
+
+  auto strctB = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
+                                              std::move(b_members));
+  ast::type::StructType stB(std::move(strctB));
+
+  ast::type::VectorType vecB(&stB, 3);
+
+  ast::StructMemberList a_members;
+  a_members.push_back(
+      std::make_unique<ast::StructMember>("mem", &vecB, std::move(decos)));
+
+  auto strctA = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
+                                              std::move(a_members));
+
+  ast::type::StructType stA(std::move(strctA));
+
+  auto var =
+      std::make_unique<ast::Variable>("c", ast::StorageClass::kNone, &stA);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  EXPECT_TRUE(td()->Determine(&m));
+
+  auto ident = std::make_unique<ast::IdentifierExpression>("c");
+  auto mem_ident = std::make_unique<ast::IdentifierExpression>("mem");
+  auto foo_ident = std::make_unique<ast::IdentifierExpression>("foo");
+  auto idx = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::IntLiteral>(&i32, 0));
+  auto swizzle = std::make_unique<ast::IdentifierExpression>("yx");
+
+  ast::MemberAccessorExpression mem(
+      std::make_unique<ast::MemberAccessorExpression>(
+          std::make_unique<ast::ArrayAccessorExpression>(
+              std::make_unique<ast::MemberAccessorExpression>(
+                  std::move(ident), std::move(mem_ident)),
+              std::move(idx)),
+          std::move(foo_ident)),
+      std::move(swizzle));
+  EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
+  ASSERT_NE(mem.result_type(), nullptr);
+  ASSERT_TRUE(mem.result_type()->IsVector());
+  EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
+  EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
+}
+
 }  // namespace
 }  // namespace tint