[semantic] Add semantic::Variable::Users()

Returns a list of ast::IdentifierExpression* nodes that reference the
variable.

Change-Id: I36f475c6ddf5482f9ae9b432190405625f379f0d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/41661
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/semantic/call.h b/src/semantic/call.h
index a4931fb..954585c 100644
--- a/src/semantic/call.h
+++ b/src/semantic/call.h
@@ -26,9 +26,12 @@
 class Call : public Castable<Call, Expression> {
  public:
   /// Constructor
+  /// @param declaration the AST node
   /// @param target the call target
   /// @param statement the statement that owns this expression
-  explicit Call(const CallTarget* target, Statement* statement);
+  explicit Call(ast::Expression* declaration,
+                const CallTarget* target,
+                Statement* statement);
 
   /// Destructor
   ~Call() override;
diff --git a/src/semantic/expression.h b/src/semantic/expression.h
index 57ed61c..6b9b4b1 100644
--- a/src/semantic/expression.h
+++ b/src/semantic/expression.h
@@ -15,6 +15,7 @@
 #ifndef SRC_SEMANTIC_EXPRESSION_H_
 #define SRC_SEMANTIC_EXPRESSION_H_
 
+#include "src/ast/expression.h"
 #include "src/semantic/node.h"
 
 namespace tint {
@@ -33,9 +34,12 @@
 class Expression : public Castable<Expression, Node> {
  public:
   /// Constructor
+  /// @param declaration the AST node
   /// @param type the resolved type of the expression
   /// @param statement the statement that owns this expression
-  explicit Expression(type::Type* type, Statement* statement);
+  explicit Expression(ast::Expression* declaration,
+                      type::Type* type,
+                      Statement* statement);
 
   /// @return the resolved type of the expression
   type::Type* Type() const { return type_; }
@@ -43,7 +47,11 @@
   /// @return the statement that owns this expression
   Statement* Stmt() const { return statement_; }
 
+  /// @returns the AST node
+  ast::Expression* Declaration() const { return declaration_; }
+
  private:
+  ast::Expression* declaration_;
   type::Type* const type_;
   Statement* const statement_;
 };
diff --git a/src/semantic/member_accessor_expression.h b/src/semantic/member_accessor_expression.h
index f46fe6e..4a7d7f6 100644
--- a/src/semantic/member_accessor_expression.h
+++ b/src/semantic/member_accessor_expression.h
@@ -26,10 +26,12 @@
     : public Castable<MemberAccessorExpression, Expression> {
  public:
   /// Constructor
+  /// @param declaration the AST node
   /// @param type the resolved type of the expression
   /// @param statement the statement that owns this expression
   /// @param is_swizzle true if this member access is for a vector swizzle
-  MemberAccessorExpression(type::Type* type,
+  MemberAccessorExpression(ast::Expression* declaration,
+                           type::Type* type,
                            Statement* statement,
                            bool is_swizzle);
 
diff --git a/src/semantic/sem_call.cc b/src/semantic/sem_call.cc
index 8f0a41f..6ba47ed 100644
--- a/src/semantic/sem_call.cc
+++ b/src/semantic/sem_call.cc
@@ -19,8 +19,10 @@
 namespace tint {
 namespace semantic {
 
-Call::Call(const CallTarget* target, Statement* statement)
-    : Base(target->ReturnType(), statement), target_(target) {}
+Call::Call(ast::Expression* declaration,
+           const CallTarget* target,
+           Statement* statement)
+    : Base(declaration, target->ReturnType(), statement), target_(target) {}
 
 Call::~Call() = default;
 
diff --git a/src/semantic/sem_expression.cc b/src/semantic/sem_expression.cc
index 36cc8c4..285493e 100644
--- a/src/semantic/sem_expression.cc
+++ b/src/semantic/sem_expression.cc
@@ -21,8 +21,12 @@
 namespace tint {
 namespace semantic {
 
-Expression::Expression(type::Type* type, Statement* statement)
-    : type_(type->UnwrapIfNeeded()), statement_(statement) {}
+Expression::Expression(ast::Expression* declaration,
+                       type::Type* type,
+                       Statement* statement)
+    : declaration_(declaration),
+      type_(type->UnwrapIfNeeded()),
+      statement_(statement) {}
 
 }  // namespace semantic
 }  // namespace tint
diff --git a/src/semantic/sem_member_accessor_expression.cc b/src/semantic/sem_member_accessor_expression.cc
index e85c7bf..470f059 100644
--- a/src/semantic/sem_member_accessor_expression.cc
+++ b/src/semantic/sem_member_accessor_expression.cc
@@ -19,10 +19,11 @@
 namespace tint {
 namespace semantic {
 
-MemberAccessorExpression::MemberAccessorExpression(type::Type* type,
+MemberAccessorExpression::MemberAccessorExpression(ast::Expression* declaration,
+                                                   type::Type* type,
                                                    Statement* statement,
                                                    bool is_swizzle)
-    : Base(type, statement), is_swizzle_(is_swizzle) {}
+    : Base(declaration, type, statement), is_swizzle_(is_swizzle) {}
 
 }  // namespace semantic
 }  // namespace tint
diff --git a/src/semantic/sem_variable.cc b/src/semantic/sem_variable.cc
index 33e84a1..8282087 100644
--- a/src/semantic/sem_variable.cc
+++ b/src/semantic/sem_variable.cc
@@ -19,8 +19,12 @@
 namespace tint {
 namespace semantic {
 
-Variable::Variable(ast::Variable* declaration, ast::StorageClass storage_class)
-    : declaration_(declaration), storage_class_(storage_class) {}
+Variable::Variable(ast::Variable* declaration,
+                   ast::StorageClass storage_class,
+                   std::vector<const Expression*> users)
+    : declaration_(declaration),
+      storage_class_(storage_class),
+      users_(std::move(users)) {}
 
 Variable::~Variable() = default;
 
diff --git a/src/semantic/variable.h b/src/semantic/variable.h
index 9032fe9..3808b89 100644
--- a/src/semantic/variable.h
+++ b/src/semantic/variable.h
@@ -19,6 +19,7 @@
 #include <vector>
 
 #include "src/ast/storage_class.h"
+#include "src/semantic/expression.h"
 #include "src/semantic/node.h"
 #include "src/type/sampler_type.h"
 
@@ -40,8 +41,10 @@
   /// Constructor
   /// @param declaration the AST declaration node
   /// @param storage_class the variable storage class
+  /// @param users the expressions that use the variable
   explicit Variable(ast::Variable* declaration,
-                    ast::StorageClass storage_class);
+                    ast::StorageClass storage_class,
+                    std::vector<const Expression*> users);
 
   /// Destructor
   ~Variable() override;
@@ -52,9 +55,13 @@
   /// @returns the storage class for the variable
   ast::StorageClass StorageClass() const { return storage_class_; }
 
+  /// @returns the expressions that use the variable
+  const std::vector<const Expression*>& Users() const { return users_; }
+
  private:
   ast::Variable* const declaration_;
   ast::StorageClass const storage_class_;
+  std::vector<const Expression*> const users_;
 };
 
 }  // namespace semantic
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 1bb2d4c..a72c3ab 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -524,14 +524,14 @@
     }
     auto* intrinsic = builder_->create<semantic::Intrinsic>(intrinsic_type,
                                                             ret_ty, parameters);
-    builder_->Sem().Add(
-        call, builder_->create<semantic::Call>(intrinsic, current_statement_));
+    builder_->Sem().Add(call, builder_->create<semantic::Call>(
+                                  call, intrinsic, current_statement_));
     SetType(call, ret_ty);
     return false;
   }
 
   builder_->Sem().Add(call, builder_->create<semantic::Call>(
-                                result.intrinsic, current_statement_));
+                                call, result.intrinsic, current_statement_));
   SetType(call, result.intrinsic->ReturnType());
   return true;
 }
@@ -566,6 +566,7 @@
                                                     var->storage_class));
     }
 
+    var->users.push_back(expr);
     set_referenced_from_function_if_needed(var, true);
     return true;
   }
@@ -818,7 +819,7 @@
 
   builder_->Sem().Add(expr,
                       builder_->create<semantic::MemberAccessorExpression>(
-                          ret, current_statement_, is_swizzle));
+                          expr, ret, current_statement_, is_swizzle));
   SetType(expr, ret);
 
   return true;
@@ -934,8 +935,20 @@
   for (auto it : variable_to_info_) {
     auto* var = it.first;
     auto* info = it.second;
-    sem.Add(var,
-            builder_->create<semantic::Variable>(var, info->storage_class));
+    std::vector<const semantic::Expression*> users;
+    for (auto* user : info->users) {
+      // Create semantic node for the identifier expression if necessary
+      auto* sem_expr = sem.Get(user);
+      if (sem_expr == nullptr) {
+        auto* type = expr_info_.at(user).type;
+        auto* stmt = expr_info_.at(user).statement;
+        sem_expr = builder_->create<semantic::Expression>(user, type, stmt);
+        sem.Add(user, sem_expr);
+      }
+      users.push_back(sem_expr);
+    }
+    sem.Add(var, builder_->create<semantic::Variable>(var, info->storage_class,
+                                                      std::move(users)));
   }
 
   auto remap_vars = [&sem](const std::vector<VariableInfo*>& in) {
@@ -965,7 +978,8 @@
     auto* call = it.first;
     auto info = it.second;
     auto* sem_func = func_info_to_sem_func.at(info.function);
-    sem.Add(call, builder_->create<semantic::Call>(sem_func, info.statement));
+    sem.Add(call,
+            builder_->create<semantic::Call>(call, sem_func, info.statement));
   }
 
   // Create semantic nodes for all remaining expression types
@@ -976,8 +990,8 @@
       // Expression has already been assigned a semantic node
       continue;
     }
-    sem.Add(expr,
-            builder_->create<semantic::Expression>(info.type, info.statement));
+    sem.Add(expr, builder_->create<semantic::Expression>(expr, info.type,
+                                                         info.statement));
   }
 }
 
diff --git a/src/type_determiner.h b/src/type_determiner.h
index 7f8ecc2..b03b1d5 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -99,13 +99,14 @@
   };
 
   /// Structure holding semantic information about a variable.
-  /// Used to build the semantic::Function nodes at the end of resolving.
+  /// Used to build the semantic::Variable nodes at the end of resolving.
   struct VariableInfo {
     explicit VariableInfo(ast::Variable* decl);
     ~VariableInfo();
 
     ast::Variable* const declaration;
     ast::StorageClass storage_class;
+    std::vector<ast::IdentifierExpression*> users;
   };
 
   /// Structure holding semantic information about a function.
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 9f871ba..71d48a3 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -111,6 +111,20 @@
     return sem_stmt ? sem_stmt->Declaration() : nullptr;
   }
 
+  bool CheckVarUsers(ast::Variable* var,
+                     std::vector<ast::Expression*>&& expected_users) {
+    auto& var_users = Sem().Get(var)->Users();
+    if (var_users.size() != expected_users.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < var_users.size(); i++) {
+      if (var_users[i]->Declaration() != expected_users[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
  private:
   std::unique_ptr<TypeDeterminer> td_;
 };
@@ -468,6 +482,8 @@
   EXPECT_EQ(StmtOf(bar_i32_init), bar_i32_decl);
   EXPECT_EQ(StmtOf(foo_f32_init), foo_f32_decl);
   EXPECT_EQ(StmtOf(bar_f32_init), bar_f32_decl);
+  EXPECT_TRUE(CheckVarUsers(foo_i32, {bar_i32->constructor()}));
+  EXPECT_TRUE(CheckVarUsers(foo_f32, {bar_f32->constructor()}));
 }
 
 TEST_F(TypeDeterminerTest, Stmt_VariableDecl_ModuleScopeAfterFunctionScope) {
@@ -513,6 +529,8 @@
   EXPECT_EQ(StmtOf(fn_i32_init), fn_i32_decl);
   EXPECT_EQ(StmtOf(mod_init), nullptr);
   EXPECT_EQ(StmtOf(fn_f32_init), fn_f32_decl);
+  EXPECT_TRUE(CheckVarUsers(fn_i32, {}));
+  EXPECT_TRUE(CheckVarUsers(mod_f32, {fn_f32->constructor()}));
 }
 
 TEST_F(TypeDeterminerTest, Expr_Error_Unknown) {
@@ -716,7 +734,7 @@
 }
 
 TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalVariable) {
-  Global("my_var", ast::StorageClass::kNone, ty.f32());
+  auto* my_var = Global("my_var", ast::StorageClass::kNone, ty.f32());
 
   auto* ident = Expr("my_var");
   WrapInFunction(ident);
@@ -726,10 +744,11 @@
   ASSERT_NE(TypeOf(ident), nullptr);
   EXPECT_TRUE(TypeOf(ident)->Is<type::Pointer>());
   EXPECT_TRUE(TypeOf(ident)->As<type::Pointer>()->type()->Is<type::F32>());
+  EXPECT_TRUE(CheckVarUsers(my_var, {ident}));
 }
 
 TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalConstant) {
-  GlobalConst("my_var", ast::StorageClass::kNone, ty.f32());
+  auto* my_var = GlobalConst("my_var", ast::StorageClass::kNone, ty.f32());
 
   auto* ident = Expr("my_var");
   WrapInFunction(ident);
@@ -738,6 +757,7 @@
 
   ASSERT_NE(TypeOf(ident), nullptr);
   EXPECT_TRUE(TypeOf(ident)->Is<type::F32>());
+  EXPECT_TRUE(CheckVarUsers(my_var, {ident}));
 }
 
 TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) {
@@ -761,6 +781,7 @@
   ASSERT_NE(TypeOf(my_var_b), nullptr);
   EXPECT_TRUE(TypeOf(my_var_b)->Is<type::F32>());
   EXPECT_EQ(StmtOf(my_var_b), assign);
+  EXPECT_TRUE(CheckVarUsers(var, {my_var_a, my_var_b}));
 }
 
 TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
@@ -768,10 +789,11 @@
   auto* my_var_b = Expr("my_var");
   auto* assign = create<ast::AssignmentStatement>(my_var_a, my_var_b);
 
+  auto* var = Var("my_var", ast::StorageClass::kNone, ty.f32());
+
   Func("my_func", ast::VariableList{}, ty.f32(),
        ast::StatementList{
-           create<ast::VariableDeclStatement>(
-               Var("my_var", ast::StorageClass::kNone, ty.f32())),
+           create<ast::VariableDeclStatement>(var),
            assign,
        },
        ast::FunctionDecorationList{});
@@ -786,6 +808,7 @@
   EXPECT_TRUE(TypeOf(my_var_b)->Is<type::Pointer>());
   EXPECT_TRUE(TypeOf(my_var_b)->As<type::Pointer>()->type()->Is<type::F32>());
   EXPECT_EQ(StmtOf(my_var_b), assign);
+  EXPECT_TRUE(CheckVarUsers(var, {my_var_a, my_var_b}));
 }
 
 TEST_F(TypeDeterminerTest, Expr_Identifier_Function_Ptr) {
diff --git a/src/writer/append_vector.cc b/src/writer/append_vector.cc
index 3e3d841..95085ec 100644
--- a/src/writer/append_vector.cc
+++ b/src/writer/append_vector.cc
@@ -56,8 +56,8 @@
 
   // Cast scalar to the vector element type
   auto* scalar_cast = b->Construct(packed_el_ty, scalar);
-  b->Sem().Add(scalar_cast,
-               b->create<semantic::Expression>(packed_el_ty, statement));
+  b->Sem().Add(scalar_cast, b->create<semantic::Expression>(
+                                scalar_cast, packed_el_ty, statement));
 
   auto* packed_ty = b->create<type::Vector>(packed_el_ty, packed_size);
 
@@ -76,8 +76,8 @@
   }
 
   auto* constructor = b->Construct(packed_ty, std::move(packed));
-  b->Sem().Add(constructor,
-               b->create<semantic::Expression>(packed_ty, statement));
+  b->Sem().Add(constructor, b->create<semantic::Expression>(
+                                constructor, packed_ty, statement));
 
   return constructor;
 }
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 4d9c78b..97bee9b 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -985,7 +985,8 @@
     auto* i32 = builder_.create<type::I32>();
     auto* zero = builder_.Expr(0);
     auto* stmt = builder_.Sem().Get(vector)->Stmt();
-    builder_.Sem().Add(zero, builder_.create<semantic::Expression>(i32, stmt));
+    builder_.Sem().Add(zero,
+                       builder_.create<semantic::Expression>(zero, i32, stmt));
     auto* packed = AppendVector(&builder_, vector, zero);
     return EmitExpression(pre, out, packed);
   };