spirv-reader: Use GenerateExpressionWithLoadIfNeeded more

* Rename GenerateNonReferenceExpression to
  GenerateExpressionWithLoadIfNeeded.
  This version takes an ast::Expression
* Add a variant that takes a sem::Expression, because the sem
  expression already knows the resolved type, and so we can save
  a lookup.
* Replace most uses of GenerateExpression ... GenerateLoadIfNeeded
  with a call to one of the above.

This is a non-functional change.
Followup to the fix in tint:1343.

Bug: tint:1343
Change-Id: If19a1bc7670edd2badc1533861d8b42f0825c7b8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/72720
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index a0eba5a..04334a2 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -408,15 +408,10 @@
     if (lhs_id == 0) {
       return false;
     }
-    auto rhs_id = GenerateExpression(assign->rhs);
+    auto rhs_id = GenerateExpressionWithLoadIfNeeded(assign->rhs);
     if (rhs_id == 0) {
       return false;
     }
-
-    // If the thing we're assigning is a reference then we must load it first.
-    auto* type = TypeOf(assign->rhs);
-    rhs_id = GenerateLoadIfNeeded(type, rhs_id);
-
     return GenerateStore(lhs_id, rhs_id);
   }
 }
@@ -706,14 +701,10 @@
 bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
   uint32_t init_id = 0;
   if (var->constructor) {
-    init_id = GenerateExpression(var->constructor);
+    init_id = GenerateExpressionWithLoadIfNeeded(var->constructor);
     if (init_id == 0) {
       return false;
     }
-    auto* type = TypeOf(var->constructor);
-    if (type->Is<sem::Reference>()) {
-      init_id = GenerateLoadIfNeeded(type, init_id);
-    }
   }
 
   if (var->is_const) {
@@ -914,12 +905,10 @@
 
 bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr,
                                     AccessorInfo* info) {
-  auto idx_id = GenerateExpression(expr->index);
+  auto idx_id = GenerateExpressionWithLoadIfNeeded(expr->index);
   if (idx_id == 0) {
     return 0;
   }
-  auto* type = TypeOf(expr->index);
-  idx_id = GenerateLoadIfNeeded(type, idx_id);
 
   // If the source is a reference, we access chain into it.
   // In the future, pointers may support access-chaining.
@@ -1183,8 +1172,19 @@
   return val;
 }
 
-uint32_t Builder::GenerateNonReferenceExpression(const ast::Expression* expr) {
+uint32_t Builder::GenerateExpressionWithLoadIfNeeded(
+    const sem::Expression* expr) {
+  // The semantic node directly knows both the AST node and the resolved type.
+  if (const auto id = GenerateExpression(expr->Declaration())) {
+    return GenerateLoadIfNeeded(expr->Type(), id);
+  }
+  return 0;
+}
+
+uint32_t Builder::GenerateExpressionWithLoadIfNeeded(
+    const ast::Expression* expr) {
   if (const auto id = GenerateExpression(expr)) {
+    // Perform a lookup to get the resolved type.
     return GenerateLoadIfNeeded(TypeOf(expr), id);
   }
   return 0;
@@ -1212,11 +1212,6 @@
   auto result = result_op();
   auto result_id = result.to_i();
 
-  auto val_id = GenerateExpression(expr->expr);
-  if (val_id == 0) {
-    return 0;
-  }
-
   spv::Op op = spv::Op::OpNop;
   switch (expr->op) {
     case ast::UnaryOp::kComplement:
@@ -1237,10 +1232,13 @@
       // Address-of converts a reference to a pointer, and dereference converts
       // a pointer to a reference. These are the same thing in SPIR-V, so this
       // is a no-op.
-      return val_id;
+      return GenerateExpression(expr->expr);
   }
 
-  val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
+  auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr);
+  if (val_id == 0) {
+    return 0;
+  }
 
   auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
   if (type_id == 0) {
@@ -1380,11 +1378,7 @@
   OperandList ops;
   for (auto* e : args) {
     uint32_t id = 0;
-    id = GenerateExpression(e->Declaration());
-    if (id == 0) {
-      return 0;
-    }
-    id = GenerateLoadIfNeeded(e->Type(), id);
+    id = GenerateExpressionWithLoadIfNeeded(e);
     if (id == 0) {
       return 0;
     }
@@ -1532,11 +1526,10 @@
     return 0;
   }
 
-  auto val_id = GenerateExpression(from_expr);
+  auto val_id = GenerateExpressionWithLoadIfNeeded(from_expr);
   if (val_id == 0) {
     return 0;
   }
-  val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id);
 
   auto* from_type = TypeOf(from_expr)->UnwrapRef();
 
@@ -1804,11 +1797,10 @@
 
 uint32_t Builder::GenerateShortCircuitBinaryExpression(
     const ast::BinaryExpression* expr) {
-  auto lhs_id = GenerateExpression(expr->lhs);
+  auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs);
   if (lhs_id == 0) {
     return false;
   }
-  lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
 
   // Get the ID of the basic block where control flow will diverge. It's the
   // last basic block generated for the left-hand-side of the operator.
@@ -1848,11 +1840,10 @@
   if (!GenerateLabel(block_id)) {
     return 0;
   }
-  auto rhs_id = GenerateExpression(expr->rhs);
+  auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs);
   if (rhs_id == 0) {
     return 0;
   }
-  rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
 
   // Get the block ID of the last basic block generated for the right-hand-side
   // expression. That block will be an immediate predecessor to the merge block.
@@ -1971,17 +1962,15 @@
     return GenerateShortCircuitBinaryExpression(expr);
   }
 
-  auto lhs_id = GenerateExpression(expr->lhs);
+  auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs);
   if (lhs_id == 0) {
     return 0;
   }
-  lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
 
-  auto rhs_id = GenerateExpression(expr->rhs);
+  auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs);
   if (rhs_id == 0) {
     return 0;
   }
-  rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
 
   auto result = result_op();
   auto result_id = result.to_i();
@@ -2258,11 +2247,7 @@
 
   size_t arg_idx = 0;
   for (auto* arg : expr->args) {
-    auto id = GenerateExpression(arg);
-    if (id == 0) {
-      return 0;
-    }
-    id = GenerateLoadIfNeeded(TypeOf(arg), id);
+    auto id = GenerateExpressionWithLoadIfNeeded(arg);
     if (id == 0) {
       return 0;
     }
@@ -2715,12 +2700,7 @@
 
   // Generates the given expression, returning the operand ID
   auto gen = [&](const sem::Expression* expr) {
-    auto val_id = GenerateExpression(expr->Declaration());
-    if (val_id == 0) {
-      return Operand::Int(0);
-    }
-    val_id = GenerateLoadIfNeeded(expr->Type(), val_id);
-
+    const auto val_id = GenerateExpressionWithLoadIfNeeded(expr);
     return Operand::Int(val_id);
   };
 
@@ -3218,11 +3198,7 @@
 
   uint32_t value_id = 0;
   if (call->Arguments().size() > 1) {
-    value_id = GenerateExpression(call->Arguments().back()->Declaration());
-    if (value_id == 0) {
-      return false;
-    }
-    value_id = GenerateLoadIfNeeded(call->Arguments().back()->Type(), value_id);
+    value_id = GenerateExpressionWithLoadIfNeeded(call->Arguments().back());
     if (value_id == 0) {
       return false;
     }
@@ -3458,11 +3434,10 @@
     return 0;
   }
 
-  auto val_id = GenerateExpression(expr->expr);
+  auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr);
   if (val_id == 0) {
     return 0;
   }
-  val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
 
   // Bitcast does not allow same types, just emit a CopyObject
   auto* to_type = TypeOf(expr)->UnwrapRef();
@@ -3489,11 +3464,10 @@
     const ast::BlockStatement* true_body,
     size_t cur_else_idx,
     const ast::ElseStatementList& else_stmts) {
-  auto cond_id = GenerateExpression(cond);
+  auto cond_id = GenerateExpressionWithLoadIfNeeded(cond);
   if (cond_id == 0) {
     return false;
   }
-  cond_id = GenerateLoadIfNeeded(TypeOf(cond), cond_id);
 
   auto merge_block = result_op();
   auto merge_block_id = merge_block.to_i();
@@ -3585,7 +3559,7 @@
     if (is_just_a_break(stmt->body) && stmt->else_statements.empty()) {
       // It's a break-if.
       TINT_ASSERT(Writer, !backedge_stack_.empty());
-      const auto cond_id = GenerateNonReferenceExpression(stmt->condition);
+      const auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition);
       if (!cond_id) {
         return false;
       }
@@ -3600,7 +3574,8 @@
           is_just_a_break(es.back()->body)) {
         // It's a break-unless.
         TINT_ASSERT(Writer, !backedge_stack_.empty());
-        const auto cond_id = GenerateNonReferenceExpression(stmt->condition);
+        const auto cond_id =
+            GenerateExpressionWithLoadIfNeeded(stmt->condition);
         if (!cond_id) {
           return false;
         }
@@ -3626,11 +3601,10 @@
 
   merge_stack_.push_back(merge_block_id);
 
-  auto cond_id = GenerateExpression(stmt->condition);
+  auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition);
   if (cond_id == 0) {
     return false;
   }
-  cond_id = GenerateLoadIfNeeded(TypeOf(stmt->condition), cond_id);
 
   auto default_block = result_op();
   auto default_block_id = default_block.to_i();
@@ -3724,11 +3698,10 @@
 
 bool Builder::GenerateReturnStatement(const ast::ReturnStatement* stmt) {
   if (stmt->value) {
-    auto val_id = GenerateExpression(stmt->value);
+    auto val_id = GenerateExpressionWithLoadIfNeeded(stmt->value);
     if (val_id == 0) {
       return false;
     }
-    val_id = GenerateLoadIfNeeded(TypeOf(stmt->value), val_id);
     if (!push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)})) {
       return false;
     }
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 1b43ee4..d2b5237 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -458,9 +458,16 @@
   /// type, then return the SPIR-V ID for the expression. Otherwise implement
   /// the WGSL Load Rule: generate an OpLoad and return the ID of the result.
   /// Returns 0 if the expression could not be generated.
-  /// @param expr the expression to be generate
+  /// @param expr the semantic expression node to be generated
   /// @returns the the ID of the expression, or loaded expression
-  uint32_t GenerateNonReferenceExpression(const ast::Expression* expr);
+  uint32_t GenerateExpressionWithLoadIfNeeded(const sem::Expression* expr);
+  /// Generates an expression. If the WGSL expression does not have reference
+  /// type, then return the SPIR-V ID for the expression. Otherwise implement
+  /// the WGSL Load Rule: generate an OpLoad and return the ID of the result.
+  /// Returns 0 if the expression could not be generated.
+  /// @param expr the AST expression to be generated
+  /// @returns the the ID of the expression, or loaded expression
+  uint32_t GenerateExpressionWithLoadIfNeeded(const ast::Expression* expr);
   /// Generates an OpLoad on the given ID if it has reference type in WGSL,
   /// othewrise return the ID itself.
   /// @param type the type of the expression