Validate function call arguments

- Add resolver/call_test.cc for new unit tests, and move a couple of
existing tests there from resolver/validation_test.cc

- Fix CalculateArrayLength transform so that it passes the address of
the u32 it creates to the internal function

- Fix tests that called functions with arguments the callee never declared

Bug: tint:664
Change-Id: If713f9828790cd51224d2392d42c01c0057cb652
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53920
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 25c90ed..5a9494e 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -589,6 +589,7 @@
     resolver/assignment_validation_test.cc
     resolver/block_test.cc
     resolver/builtins_validation_test.cc
+    resolver/call_test.cc
     resolver/control_block_validation_test.cc
     resolver/decoration_validation_test.cc
     resolver/entry_point_validation_test.cc
diff --git a/src/program_builder.h b/src/program_builder.h
index 02eae06..b04a5ad 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -1403,11 +1403,24 @@
                                           Expr(std::forward<EXPR>(expr)));
   }
 
+  /// @param source the source information
   /// @param func the function name
   /// @param args the function call arguments
   /// @returns a `ast::CallExpression` to the function `func`, with the
   /// arguments of `args` converted to `ast::Expression`s using `Expr()`.
   template <typename NAME, typename... ARGS>
+  ast::CallExpression* Call(const Source& source, NAME&& func, ARGS&&... args) {
+    return create<ast::CallExpression>(source, Expr(func),
+                                       ExprList(std::forward<ARGS>(args)...));
+  }
+
+  /// @param func the function name
+  /// @param args the function call arguments
+  /// @returns a `ast::CallExpression` to the function `func`, with the
+  /// arguments of `args` converted to `ast::Expression`s using `Expr()`.
+  template <typename NAME,
+            typename... ARGS,
+            traits::EnableIfIsNotType<traits::Decay<NAME>, Source>* = nullptr>
   ast::CallExpression* Call(NAME&& func, ARGS&&... args) {
     return create<ast::CallExpression>(Expr(func),
                                        ExprList(std::forward<ARGS>(args)...));
diff --git a/src/resolver/call_test.cc b/src/resolver/call_test.cc
new file mode 100644
index 0000000..f9e2af5
--- /dev/null
+++ b/src/resolver/call_test.cc
@@ -0,0 +1,194 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/ast/call_statement.h"
+#include "src/resolver/resolver_test_helper.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+// Helpers and typedefs
+template <typename T>
+using DataType = builder::DataType<T>;
+template <int N, typename T>
+using vec = builder::vec<N, T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+template <int N, int M, typename T>
+using mat = builder::mat<N, M, T>;
+template <typename T>
+using mat2x2 = builder::mat2x2<T>;
+template <typename T>
+using mat2x3 = builder::mat2x3<T>;
+template <typename T>
+using mat3x2 = builder::mat3x2<T>;
+template <typename T>
+using mat3x3 = builder::mat3x3<T>;
+template <typename T>
+using mat4x4 = builder::mat4x4<T>;
+template <typename T, int ID = 0>
+using alias = builder::alias<T, ID>;
+template <typename T>
+using alias1 = builder::alias1<T>;
+template <typename T>
+using alias2 = builder::alias2<T>;
+template <typename T>
+using alias3 = builder::alias3<T>;
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+
+using ResolverCallTest = ResolverTest;
+
+TEST_F(ResolverCallTest, Recursive_Invalid) {
+  // fn main() {main(); }
+
+  SetSource(Source::Location{12, 34});
+  auto* call_expr = Call("main");
+  ast::VariableList params0;
+
+  Func("main", params0, ty.void_(),
+       ast::StatementList{
+           create<ast::CallStatement>(call_expr),
+       },
+       ast::DecorationList{
+           Stage(ast::PipelineStage::kVertex),
+       });
+
+  EXPECT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(r()->error(),
+            "12:34 error v-0004: recursion is not permitted. 'main' attempted "
+            "to call "
+            "itself.");
+}
+
+TEST_F(ResolverCallTest, Undeclared_Invalid) {
+  // fn main() {func(); return; }
+  // fn func() { return; }
+
+  SetSource(Source::Location{12, 34});
+  auto* call_expr = Call("func");
+  ast::VariableList params0;
+
+  Func("main", params0, ty.f32(),
+       ast::StatementList{
+           create<ast::CallStatement>(call_expr),
+           Return(),
+       },
+       ast::DecorationList{});
+
+  Func("func", params0, ty.f32(),
+       ast::StatementList{
+           Return(),
+       },
+       ast::DecorationList{});
+
+  EXPECT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(r()->error(),
+            "12:34 error: v-0006: unable to find called function: func");
+}
+
+struct Params {
+  builder::ast_expr_func_ptr create_value;
+  builder::ast_type_func_ptr create_type;
+};
+
+template <typename T>
+constexpr Params ParamsFor() {
+  return Params{DataType<T>::Expr, DataType<T>::AST};
+}
+
+static constexpr Params all_param_types[] = {
+    ParamsFor<bool>(),         //
+    ParamsFor<u32>(),          //
+    ParamsFor<i32>(),          //
+    ParamsFor<f32>(),          //
+    ParamsFor<vec3<bool>>(),   //
+    ParamsFor<vec3<i32>>(),    //
+    ParamsFor<vec3<u32>>(),    //
+    ParamsFor<vec3<f32>>(),    //
+    ParamsFor<mat3x3<i32>>(),  //
+    ParamsFor<mat3x3<u32>>(),  //
+    ParamsFor<mat3x3<f32>>(),  //
+    ParamsFor<mat2x3<i32>>(),  //
+    ParamsFor<mat2x3<u32>>(),  //
+    ParamsFor<mat2x3<f32>>(),  //
+    ParamsFor<mat3x2<i32>>(),  //
+    ParamsFor<mat3x2<u32>>(),  //
+    ParamsFor<mat3x2<f32>>()   //
+};
+
+TEST_F(ResolverCallTest, Valid) {
+  ast::VariableList params;
+  ast::ExpressionList args;
+  for (auto& p : all_param_types) {
+    params.push_back(Param(Sym(), p.create_type(*this)));
+    args.push_back(p.create_value(*this, 0));
+  }
+
+  Func("foo", std::move(params), ty.void_(), {Return()});
+  auto* call = Call("foo", std::move(args));
+  WrapInFunction(call);
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallTest, TooFewArgs) {
+  Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
+       {Return()});
+  auto* call = Call(Source{{12, 34}}, "foo", 1);
+  WrapInFunction(call);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: too few arguments in call to 'foo', expected 2, got 1");
+}
+
+TEST_F(ResolverCallTest, TooManyArgs) {
+  Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
+       {Return()});
+  auto* call = Call(Source{{12, 34}}, "foo", 1, 1.0f, 1.0f);
+  WrapInFunction(call);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: too many arguments in call to 'foo', expected 2, got 3");
+}
+
+TEST_F(ResolverCallTest, MismatchedArgs) {
+  Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
+       {Return()});
+  auto* call = Call("foo", Expr(Source{{12, 34}}, true), 1.0f);
+  WrapInFunction(call);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: type mismatch for argument 1 in call to 'foo', "
+            "expected 'i32', got 'bool'");
+}
+
+}  // namespace
+}  // namespace resolver
+}  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 27fbd1b..41c4114 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1701,51 +1701,9 @@
       return false;
     }
   } else {
-    if (current_function_) {
-      auto callee_func_it = symbol_to_function_.find(ident->symbol());
-      if (callee_func_it == symbol_to_function_.end()) {
-        if (current_function_->declaration->symbol() == ident->symbol()) {
-          diagnostics_.add_error("v-0004",
-                                 "recursion is not permitted. '" + name +
-                                     "' attempted to call itself.",
-                                 call->source());
-        } else {
-          diagnostics_.add_error(
-              "v-0006: unable to find called function: " + name,
-              call->source());
-        }
-        return false;
-      }
-      auto* callee_func = callee_func_it->second;
-
-      callee_func->callsites.push_back(call);
-
-      // Note: Requires called functions to be resolved first.
-      // This is currently guaranteed as functions must be declared before
-      // use.
-      current_function_->transitive_calls.add(callee_func);
-      for (auto* transitive_call : callee_func->transitive_calls) {
-        current_function_->transitive_calls.add(transitive_call);
-      }
-
-      // We inherit any referenced variables from the callee.
-      for (auto* var : callee_func->referenced_module_vars) {
-        set_referenced_from_function_if_needed(var, false);
-      }
-    }
-
-    auto iter = symbol_to_function_.find(ident->symbol());
-    if (iter == symbol_to_function_.end()) {
-      diagnostics_.add_error(
-          "v-0005: function must be declared before use: '" + name + "'",
-          call->source());
+    if (!FunctionCall(call)) {
       return false;
     }
-
-    auto* function = iter->second;
-    function_calls_.emplace(call,
-                            FunctionCallInfo{function, current_statement_});
-    SetType(call, function->return_type, function->return_type_name);
   }
 
   return true;
@@ -1775,6 +1733,79 @@
   return true;
 }
 
+bool Resolver::FunctionCall(const ast::CallExpression* call) {
+  auto* ident = call->func();
+  auto name = builder_->Symbols().NameFor(ident->symbol());
+
+  auto callee_func_it = symbol_to_function_.find(ident->symbol());
+  if (callee_func_it == symbol_to_function_.end()) {
+    if (current_function_ &&
+        current_function_->declaration->symbol() == ident->symbol()) {
+      diagnostics_.add_error("v-0004",
+                             "recursion is not permitted. '" + name +
+                                 "' attempted to call itself.",
+                             call->source());
+    } else {
+      diagnostics_.add_error("v-0006: unable to find called function: " + name,
+                             call->source());
+    }
+    return false;
+  }
+  auto* callee_func = callee_func_it->second;
+
+  if (current_function_) {
+    callee_func->callsites.push_back(call);
+
+    // Note: Requires called functions to be resolved first.
+    // This is currently guaranteed as functions must be declared before
+    // use.
+    current_function_->transitive_calls.add(callee_func);
+    for (auto* transitive_call : callee_func->transitive_calls) {
+      current_function_->transitive_calls.add(transitive_call);
+    }
+
+    // We inherit any referenced variables from the callee.
+    for (auto* var : callee_func->referenced_module_vars) {
+      set_referenced_from_function_if_needed(var, false);
+    }
+  }
+
+  // Validate that the number of arguments matches the number of parameters
+  if (call->params().size() != callee_func->parameters.size()) {
+    bool more = call->params().size() > callee_func->parameters.size();
+    diagnostics_.add_error(
+        "too " + (more ? std::string("many") : std::string("few")) +
+            " arguments in call to '" + name + "', expected " +
+            std::to_string(callee_func->parameters.size()) + ", got " +
+            std::to_string(call->params().size()),
+        call->source());
+    return false;
+  }
+
+  // Validate that each argument's type matches its parameter's type
+  for (size_t i = 0; i < call->params().size(); ++i) {
+    const VariableInfo* param = callee_func->parameters[i];
+    const ast::Expression* arg_expr = call->params()[i];
+    auto* arg_type = TypeOf(arg_expr)->UnwrapRef();
+
+    if (param->type != arg_type) {
+      diagnostics_.add_error(
+          "type mismatch for argument " + std::to_string(i + 1) +
+              " in call to '" + name + "', expected '" +
+              param->type->FriendlyName(builder_->Symbols()) + "', got '" +
+              arg_type->FriendlyName(builder_->Symbols()) + "'",
+          arg_expr->source());
+      return false;
+    }
+  }
+
+  function_calls_.emplace(call,
+                          FunctionCallInfo{callee_func, current_statement_});
+  SetType(call, callee_func->return_type, callee_func->return_type_name);
+
+  return true;
+}
+
 bool Resolver::Constructor(ast::ConstructorExpression* expr) {
   if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
     for (auto* value : type_ctor->values()) {
@@ -2514,11 +2545,11 @@
   return nullptr;
 }
 
-void Resolver::SetType(ast::Expression* expr, const sem::Type* type) {
+void Resolver::SetType(const ast::Expression* expr, const sem::Type* type) {
   SetType(expr, type, type->FriendlyName(builder_->Symbols()));
 }
 
-void Resolver::SetType(ast::Expression* expr,
+void Resolver::SetType(const ast::Expression* expr,
                        const sem::Type* type,
                        const std::string& type_name) {
   if (expr_info_.count(expr)) {
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 03715fd..9f91027 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -234,6 +234,7 @@
   bool Identifier(ast::IdentifierExpression*);
   bool IfStatement(ast::IfStatement*);
   bool IntrinsicCall(ast::CallExpression*, sem::IntrinsicType);
+  bool FunctionCall(const ast::CallExpression* call);
   bool LoopStatement(ast::LoopStatement*);
   bool MemberAccessor(ast::MemberAccessorExpression*);
   bool Parameter(ast::Variable* param);
@@ -345,7 +346,7 @@
   /// assigns this semantic node to the expression `expr`.
   /// @param expr the expression
   /// @param type the resolved type
-  void SetType(ast::Expression* expr, const sem::Type* type);
+  void SetType(const ast::Expression* expr, const sem::Type* type);
 
   /// Creates a sem::Expression node with the resolved type `type`, the declared
   /// type name `type_name` and assigns this semantic node to the expression
@@ -353,7 +354,7 @@
   /// @param expr the expression
   /// @param type the resolved type
   /// @param type_name the declared type name
-  void SetType(ast::Expression* expr,
+  void SetType(const ast::Expression* expr,
                const sem::Type* type,
                const std::string& type_name);
 
@@ -396,7 +397,8 @@
   std::vector<FunctionInfo*> entry_points_;
   std::unordered_map<const ast::Function*, FunctionInfo*> function_to_info_;
   std::unordered_map<const ast::Variable*, VariableInfo*> variable_to_info_;
-  std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;
+  std::unordered_map<const ast::CallExpression*, FunctionCallInfo>
+      function_calls_;
   std::unordered_map<const ast::Expression*, ExpressionInfo> expr_info_;
   std::unordered_map<Symbol, TypeDeclInfo> named_type_info_;
 
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index d225fbb..e060682 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -560,8 +560,7 @@
 }
 
 TEST_F(ResolverTest, Expr_Call_WithParams) {
-  ast::VariableList params;
-  Func("my_func", params, ty.f32(),
+  Func("my_func", {Param(Sym(), ty.f32())}, ty.f32(),
        {
            Return(1.2f),
        });
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index a0a5f23..1c62b89 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -83,56 +83,6 @@
             "2:30 error: unknown statement type for type determination: Fake");
 }
 
-TEST_F(ResolverValidationTest, Stmt_Call_undeclared) {
-  // fn main() {func(); return; }
-  // fn func() { return; }
-
-  SetSource(Source::Location{12, 34});
-  auto* call_expr = Call("func");
-  ast::VariableList params0;
-
-  Func("main", params0, ty.f32(),
-       ast::StatementList{
-           create<ast::CallStatement>(call_expr),
-           Return(),
-       },
-       ast::DecorationList{});
-
-  Func("func", params0, ty.f32(),
-       ast::StatementList{
-           Return(),
-       },
-       ast::DecorationList{});
-
-  EXPECT_FALSE(r()->Resolve());
-
-  EXPECT_EQ(r()->error(),
-            "12:34 error: v-0006: unable to find called function: func");
-}
-
-TEST_F(ResolverValidationTest, Stmt_Call_recursive) {
-  // fn main() {main(); }
-
-  SetSource(Source::Location{12, 34});
-  auto* call_expr = Call("main");
-  ast::VariableList params0;
-
-  Func("main", params0, ty.void_(),
-       ast::StatementList{
-           create<ast::CallStatement>(call_expr),
-       },
-       ast::DecorationList{
-           Stage(ast::PipelineStage::kVertex),
-       });
-
-  EXPECT_FALSE(r()->Resolve());
-
-  EXPECT_EQ(r()->error(),
-            "12:34 error v-0004: recursion is not permitted. 'main' attempted "
-            "to call "
-            "itself.");
-}
-
 TEST_F(ResolverValidationTest, Stmt_If_NonBool) {
   // if (1.23f) {}
 
diff --git a/src/sem/call.cc b/src/sem/call.cc
index baa4425..3abb91e 100644
--- a/src/sem/call.cc
+++ b/src/sem/call.cc
@@ -19,7 +19,7 @@
 namespace tint {
 namespace sem {
 
-Call::Call(ast::Expression* declaration,
+Call::Call(const ast::Expression* declaration,
            const CallTarget* target,
            Statement* statement)
     : Base(declaration, target->ReturnType(), statement), target_(target) {}
diff --git a/src/sem/call.h b/src/sem/call.h
index a3e3da7..d2fdb31 100644
--- a/src/sem/call.h
+++ b/src/sem/call.h
@@ -29,7 +29,7 @@
   /// @param declaration the AST node
   /// @param target the call target
   /// @param statement the statement that owns this expression
-  Call(ast::Expression* declaration,
+  Call(const ast::Expression* declaration,
        const CallTarget* target,
        Statement* statement);
 
diff --git a/src/sem/expression.cc b/src/sem/expression.cc
index 74fb0c2..7286dc6 100644
--- a/src/sem/expression.cc
+++ b/src/sem/expression.cc
@@ -19,7 +19,7 @@
 namespace tint {
 namespace sem {
 
-Expression::Expression(ast::Expression* declaration,
+Expression::Expression(const ast::Expression* declaration,
                        const sem::Type* type,
                        Statement* statement)
     : declaration_(declaration), type_(type), statement_(statement) {
diff --git a/src/sem/expression.h b/src/sem/expression.h
index 73d971f..8c2e304 100644
--- a/src/sem/expression.h
+++ b/src/sem/expression.h
@@ -31,7 +31,7 @@
   /// @param declaration the AST node
   /// @param type the resolved type of the expression
   /// @param statement the statement that owns this expression
-  Expression(ast::Expression* declaration,
+  Expression(const ast::Expression* declaration,
              const sem::Type* type,
              Statement* statement);
 
@@ -42,10 +42,12 @@
   Statement* Stmt() const { return statement_; }
 
   /// @returns the AST node
-  ast::Expression* Declaration() const { return declaration_; }
+  ast::Expression* Declaration() const {
+    return const_cast<ast::Expression*>(declaration_);
+  }
 
  private:
-  ast::Expression* declaration_;
+  const ast::Expression* declaration_;
   const sem::Type* const type_;
   Statement* const statement_;
 };
diff --git a/src/traits.h b/src/traits.h
index f35fe0f..1fda4f0 100644
--- a/src/traits.h
+++ b/src/traits.h
@@ -20,6 +20,10 @@
 namespace tint {
 namespace traits {
 
+/// Convenience type definition for std::decay<T>::type
+template <typename T>
+using Decay = typename std::decay<T>::type;
+
 /// NthTypeOf returns the `N`th type in `Types`
 template <int N, typename... Types>
 using NthTypeOf = typename std::tuple_element<N, std::tuple<Types...>>::type;
@@ -38,7 +42,7 @@
   /// Arg is the raw type of the `N`th parameter of the function
   using Arg = NthTypeOf<N, Args...>;
   /// The type of the `N`th parameter of the function
-  using type = typename std::decay<Arg>::type;
+  using type = Decay<Arg>;
 };
 
 /// ParamType specialization for a non-static method.
@@ -47,7 +51,7 @@
   /// Arg is the raw type of the `N`th parameter of the function
   using Arg = NthTypeOf<N, Args...>;
   /// The type of the `N`th parameter of the function
-  using type = typename std::decay<Arg>::type;
+  using type = Decay<Arg>;
 };
 
 /// ParamType specialization for a non-static, const method.
@@ -56,7 +60,7 @@
   /// Arg is the raw type of the `N`th parameter of the function
   using Arg = NthTypeOf<N, Args...>;
   /// The type of the `N`th parameter of the function
-  using type = typename std::decay<Arg>::type;
+  using type = Decay<Arg>;
 };
 
 /// ParamTypeT is an alias to `typename ParamType<F, N>::type`.
@@ -66,21 +70,26 @@
 /// `IsTypeOrDerived<T, BASE>::value` is true iff `T` is of type `BASE`, or
 /// derives from `BASE`.
 template <typename T, typename BASE>
-using IsTypeOrDerived = std::integral_constant<
-    bool,
-    std::is_base_of<BASE, typename std::decay<T>::type>::value ||
-        std::is_same<BASE, typename std::decay<T>::type>::value>;
+using IsTypeOrDerived =
+    std::integral_constant<bool,
+                           std::is_base_of<BASE, Decay<T>>::value ||
+                               std::is_same<BASE, Decay<T>>::value>;
 
 /// If `CONDITION` is true then EnableIf resolves to type T, otherwise an
 /// invalid type.
 template <bool CONDITION, typename T>
 using EnableIf = typename std::enable_if<CONDITION, T>::type;
 
-/// If T is a base of BASE then EnableIfIsType resolves to type T, otherwise an
-/// invalid type.
+/// If `T` is of type `BASE`, or derives from `BASE`, then EnableIfIsType
+/// resolves to type `T`, otherwise an invalid type.
 template <typename T, typename BASE>
 using EnableIfIsType = EnableIf<IsTypeOrDerived<T, BASE>::value, T>;
 
+/// If `T` is neither of type `BASE` nor derived from `BASE`, then
+/// EnableIfIsNotType resolves to type `T`, otherwise an invalid type.
+template <typename T, typename BASE>
+using EnableIfIsNotType = EnableIf<!IsTypeOrDerived<T, BASE>::value, T>;
+
 }  // namespace traits
 }  // namespace tint
 
diff --git a/src/transform/calculate_array_length.cc b/src/transform/calculate_array_length.cc
index 0feadc1..74962fb 100644
--- a/src/transform/calculate_array_length.cc
+++ b/src/transform/calculate_array_length.cc
@@ -182,14 +182,15 @@
                     ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(),
                                  ast::StorageClass::kNone, ctx.dst->Expr(0u)));
 
-                // Call storage_buffer.GetDimensions(buffer_size_result)
+                // Call storage_buffer.GetDimensions(&buffer_size_result)
                 auto* call_get_dims =
                     ctx.dst->create<ast::CallStatement>(ctx.dst->Call(
                         // BufferSizeIntrinsic(X, ARGS...) is
                         // translated to:
                         //  X.GetDimensions(ARGS..) by the writer
                         buffer_size, ctx.Clone(storage_buffer_expr),
-                        buffer_size_result->variable()->symbol()));
+                        ctx.dst->AddressOf(ctx.dst->Expr(
+                            buffer_size_result->variable()->symbol()))));
 
                 // Calculate actual array length
                 //                total_storage_buffer_size - array_offset
diff --git a/src/transform/calculate_array_length_test.cc b/src/transform/calculate_array_length_test.cc
index 557454b..4821c58 100644
--- a/src/transform/calculate_array_length_test.cc
+++ b/src/transform/calculate_array_length_test.cc
@@ -53,7 +53,7 @@
 [[stage(compute)]]
 fn main() {
   var tint_symbol_1 : u32 = 0u;
-  tint_symbol(sb, tint_symbol_1);
+  tint_symbol(sb, &(tint_symbol_1));
   let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
   var len : u32 = tint_symbol_2;
 }
@@ -97,7 +97,7 @@
 [[stage(compute)]]
 fn main() {
   var tint_symbol_1 : u32 = 0u;
-  tint_symbol(sb, tint_symbol_1);
+  tint_symbol(sb, &(tint_symbol_1));
   let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
   var a : u32 = tint_symbol_2;
   var b : u32 = tint_symbol_2;
@@ -143,7 +143,7 @@
 [[stage(compute)]]
 fn main() {
   var tint_symbol_1 : u32 = 0u;
-  tint_symbol(sb, tint_symbol_1);
+  tint_symbol(sb, &(tint_symbol_1));
   let tint_symbol_2 : u32 = ((tint_symbol_1 - 8u) / 64u);
   var len : u32 = tint_symbol_2;
 }
@@ -192,13 +192,13 @@
 fn main() {
   if (true) {
     var tint_symbol_1 : u32 = 0u;
-    tint_symbol(sb, tint_symbol_1);
+    tint_symbol(sb, &(tint_symbol_1));
     let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
     var len : u32 = tint_symbol_2;
   } else {
     if (true) {
       var tint_symbol_3 : u32 = 0u;
-      tint_symbol(sb, tint_symbol_3);
+      tint_symbol(sb, &(tint_symbol_3));
       let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
       var len : u32 = tint_symbol_4;
     }
@@ -263,10 +263,10 @@
 [[stage(compute)]]
 fn main() {
   var tint_symbol_1 : u32 = 0u;
-  tint_symbol(sb1, tint_symbol_1);
+  tint_symbol(sb1, &(tint_symbol_1));
   let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
   var tint_symbol_4 : u32 = 0u;
-  tint_symbol_3(sb2, tint_symbol_4);
+  tint_symbol_3(sb2, &(tint_symbol_4));
   let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
   var len1 : u32 = tint_symbol_2;
   var len2 : u32 = tint_symbol_5;
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index 610e517..743bb2b 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -483,8 +483,13 @@
 TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) {
   // foo(a && b, c || d, (a || c) && (b || d))
 
-  Func("foo", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{});
+  Func("foo",
+       {
+           Param(Sym(), ty.bool_()),
+           Param(Sym(), ty.bool_()),
+           Param(Sym(), ty.bool_()),
+       },
+       ty.void_(), ast::StatementList{}, ast::DecorationList{});
   Global("a", ty.bool_(), ast::StorageClass::kPrivate);
   Global("b", ty.bool_(), ast::StorageClass::kPrivate);
   Global("c", ty.bool_(), ast::StorageClass::kPrivate);
diff --git a/src/writer/hlsl/generator_impl_call_test.cc b/src/writer/hlsl/generator_impl_call_test.cc
index b5bb089..eadfcd4 100644
--- a/src/writer/hlsl/generator_impl_call_test.cc
+++ b/src/writer/hlsl/generator_impl_call_test.cc
@@ -36,8 +36,12 @@
 }
 
 TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithParams) {
-  Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{});
+  Func("my_func",
+       {
+           Param(Sym(), ty.f32()),
+           Param(Sym(), ty.f32()),
+       },
+       ty.void_(), ast::StatementList{}, ast::DecorationList{});
   Global("param1", ty.f32(), ast::StorageClass::kPrivate);
   Global("param2", ty.f32(), ast::StorageClass::kPrivate);
 
@@ -51,8 +55,12 @@
 }
 
 TEST_F(HlslGeneratorImplTest_Call, EmitStatement_Call) {
-  Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{});
+  Func("my_func",
+       {
+           Param(Sym(), ty.f32()),
+           Param(Sym(), ty.f32()),
+       },
+       ty.void_(), ast::StatementList{}, ast::DecorationList{});
   Global("param1", ty.f32(), ast::StorageClass::kPrivate);
   Global("param2", ty.f32(), ast::StorageClass::kPrivate);
 
diff --git a/src/writer/msl/generator_impl_call_test.cc b/src/writer/msl/generator_impl_call_test.cc
index bd015b6..c165579 100644
--- a/src/writer/msl/generator_impl_call_test.cc
+++ b/src/writer/msl/generator_impl_call_test.cc
@@ -36,8 +36,12 @@
 }
 
 TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) {
-  Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{});
+  Func("my_func",
+       {
+           Param(Sym(), ty.f32()),
+           Param(Sym(), ty.f32()),
+       },
+       ty.void_(), ast::StatementList{}, ast::DecorationList{});
   Global("param1", ty.f32(), ast::StorageClass::kInput);
   Global("param2", ty.f32(), ast::StorageClass::kInput);
 
@@ -51,8 +55,12 @@
 }
 
 TEST_F(MslGeneratorImplTest, EmitStatement_Call) {
-  Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{});
+  Func("my_func",
+       {
+           Param(Sym(), ty.f32()),
+           Param(Sym(), ty.f32()),
+       },
+       ty.void_(), ast::StatementList{}, ast::DecorationList{});
   Global("param1", ty.f32(), ast::StorageClass::kInput);
   Global("param2", ty.f32(), ast::StorageClass::kInput);
 
diff --git a/src/writer/wgsl/generator_impl_call_test.cc b/src/writer/wgsl/generator_impl_call_test.cc
index bf701cd..ecce272 100644
--- a/src/writer/wgsl/generator_impl_call_test.cc
+++ b/src/writer/wgsl/generator_impl_call_test.cc
@@ -36,8 +36,12 @@
 }
 
 TEST_F(WgslGeneratorImplTest, EmitExpression_Call_WithParams) {
-  Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{});
+  Func("my_func",
+       {
+           Param(Sym(), ty.f32()),
+           Param(Sym(), ty.f32()),
+       },
+       ty.void_(), ast::StatementList{}, ast::DecorationList{});
   Global("param1", ty.f32(), ast::StorageClass::kPrivate);
   Global("param2", ty.f32(), ast::StorageClass::kPrivate);
 
@@ -51,8 +55,12 @@
 }
 
 TEST_F(WgslGeneratorImplTest, EmitStatement_Call) {
-  Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{});
+  Func("my_func",
+       {
+           Param(Sym(), ty.f32()),
+           Param(Sym(), ty.f32()),
+       },
+       ty.void_(), ast::StatementList{}, ast::DecorationList{});
   Global("param1", ty.f32(), ast::StorageClass::kPrivate);
   Global("param2", ty.f32(), ast::StorageClass::kPrivate);