validation: validate function call pointer parameter

Each argument of a function call of pointer type must be one of:
- An address-of expression of a variable identifier expression
- A function parameter
Also added source location to duplicate struct member name unittest

Bug: tint:983
Change-Id: Ic5ab010b2ed76207a1d8d3ef9f66140ea95f7e72
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58480
Auto-Submit: Sarah Mashayekhi <sarahmashay@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/resolver/call_validation_test.cc b/src/resolver/call_validation_test.cc
index d94b70d..0fc2212 100644
--- a/src/resolver/call_validation_test.cc
+++ b/src/resolver/call_validation_test.cc
@@ -128,6 +128,188 @@
             "intentional wrap the function call in ignore()");
 }
 
+TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) {
+  // fn foo(p: ptr<function, i32>) {}
+  // fn main() {
+  //   var z: i32 = 1;
+  //   foo(&z);
+  // }
+  auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+  Func("foo", {param}, ty.void_(), {});
+  Func("main", {}, ty.void_(),
+       ast::StatementList{
+           Decl(Var("z", ty.i32(), Expr(1))),
+           create<ast::CallStatement>(
+               Call("foo", AddressOf(Source{{12, 34}}, Expr("z")))),
+       });
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
+  // fn foo(p: ptr<function, i32>) {}
+  // fn main() {
+  //   let z: i32 = 1;
+  //   foo(&z);
+  // }
+  auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+  Func("foo", {param}, ty.void_(), {});
+  Func("main", {}, ty.void_(),
+       ast::StatementList{
+           Decl(Const("z", ty.i32(), Expr(1))),
+           create<ast::CallStatement>(
+               Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
+       });
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
+  // struct S { m: i32; };
+  // fn foo(p: ptr<function, i32>) {}
+  // fn main() {
+  //   var v: S;
+  //   foo(&v.m);
+  // }
+  auto* S = Structure("S", {Member("m", ty.i32())});
+  auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+  Func("foo", {param}, ty.void_(), {});
+  Func("main", {}, ty.void_(),
+       ast::StatementList{
+           Decl(Var("v", ty.Of(S))),
+           create<ast::CallStatement>(Call(
+               "foo", AddressOf(Source{{12, 34}}, MemberAccessor("v", "m")))),
+       });
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: expected an address-of expression of a variable "
+            "identifier expression or a function parameter");
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
+  // struct S { m: i32; };
+  // fn foo(p: ptr<function, i32>) {}
+  // fn main() {
+  //   let v: S = S();
+  //   foo(&v.m);
+  // }
+  auto* S = Structure("S", {Member("m", ty.i32())});
+  auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+  Func("foo", {param}, ty.void_(), {});
+  Func("main", {}, ty.void_(),
+       ast::StatementList{
+           Decl(Const("v", ty.Of(S), Construct(ty.Of(S)))),
+           create<ast::CallStatement>(Call(
+               "foo",
+               AddressOf(Expr(Source{{12, 34}}, MemberAccessor("v", "m"))))),
+       });
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
+  // fn foo(p: ptr<function, i32>) {}
+  // fn bar(p: ptr<function, i32>) {
+  // foo(p);
+  // }
+  Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+       ty.void_(), {});
+  Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+       ty.void_(),
+       ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) {
+  // fn foo(p: ptr<function, i32>) {}
+  // fn bar(p: ptr<function, i32>) {
+  // foo(p);
+  // }
+  // [[stage(fragment)]]
+  // fn main() {
+  //   var v: i32;
+  //   bar(&v);
+  // }
+  Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+       ty.void_(), {});
+  Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+       ty.void_(),
+       ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
+  Func("main", ast::VariableList{}, ty.void_(),
+       {
+           Decl(Var("v", ty.i32(), Expr(1))),
+           create<ast::CallStatement>(Call("foo", AddressOf(Expr("v")))),
+       },
+       {
+           Stage(ast::PipelineStage::kFragment),
+       });
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, LetPointer) {
+  // fn x(p : ptr<function, i32>) -> i32 {}
+  // [[stage(fragment)]]
+  // fn main() {
+  //   var v: i32;
+  //   let p: ptr<function, i32> = &v;
+  //   var c: i32 = x(p);
+  // }
+  Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+       ty.void_(), {});
+  auto* v = Var("v", ty.i32());
+  auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
+                  AddressOf(v));
+  auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
+                Call("x", Expr(Source{{12, 34}}, p)));
+  Func("main", ast::VariableList{}, ty.void_(),
+       {
+           Decl(v),
+           Decl(p),
+           Decl(c),
+       },
+       {
+           Stage(ast::PipelineStage::kFragment),
+       });
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: expected an address-of expression of a variable "
+            "identifier expression or a function parameter");
+}
+
+TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
+  // let p: ptr<private, i32> = &v;
+  // fn foo(p : ptr<private, i32>) -> i32 {}
+  // var v: i32;
+  // [[stage(fragment)]]
+  // fn main() {
+  //   var c: i32 = foo(p);
+  // }
+  Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kPrivate))},
+       ty.void_(), {});
+  auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
+  auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kPrivate),
+                  AddressOf(v));
+  auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
+                Call("foo", Expr(Source{{12, 34}}, p)));
+  Func("main", ast::VariableList{}, ty.void_(),
+       {
+           Decl(p),
+           Decl(c),
+       },
+       {
+           Stage(ast::PipelineStage::kFragment),
+       });
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: expected an address-of expression of a variable "
+            "identifier expression or a function parameter");
+}
+
 }  // namespace
 }  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 374f99b..e485c05 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2578,7 +2578,21 @@
     }
   }
 
-  // Validate number of arguments match number of parameters
+  function_calls_.emplace(call,
+                          FunctionCallInfo{callee_func, current_statement_});
+  SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
+
+  if (!ValidateFunctionCall(call, callee_func)) {
+    return false;
+  }
+  return true;
+}
+
+bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
+                                    const FunctionInfo* callee_func) {
+  auto* ident = call->func();
+  auto name = builder_->Symbols().NameFor(ident->symbol());
+
   if (call->params().size() != callee_func->parameters.size()) {
     bool more = call->params().size() > callee_func->parameters.size();
     AddError("too " + (more ? std::string("many") : std::string("few")) +
@@ -2589,7 +2603,6 @@
     return false;
   }
 
-  // Validate arguments match parameter types
   for (size_t i = 0; i < call->params().size(); ++i) {
     const VariableInfo* param = callee_func->parameters[i];
     const ast::Expression* arg_expr = call->params()[i];
@@ -2603,12 +2616,48 @@
                arg_expr->source());
       return false;
     }
+
+    if (param->declaration->type()->Is<ast::Pointer>()) {
+      auto is_valid = false;
+      if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
+        VariableInfo* var;
+        if (!variable_stack_.get(ident_expr->symbol(), &var)) {
+          TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
+          return false;
+        }
+        if (var->kind == VariableKind::kParameter) {
+          is_valid = true;
+        }
+      } else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
+        if (unary->op() == ast::UnaryOp::kAddressOf) {
+          if (auto* ident_unary =
+                  unary->expr()->As<ast::IdentifierExpression>()) {
+            VariableInfo* var;
+            if (!variable_stack_.get(ident_unary->symbol(), &var)) {
+              TINT_ICE(Resolver, diagnostics_)
+                  << "failed to resolve identifier";
+              return false;
+            }
+            if (var->declaration->is_const()) {
+              TINT_ICE(Resolver, diagnostics_)
+                  << "Resolver::FunctionCall() encountered an address-of "
+                     "expression of a constant identifier expression";
+              return false;
+            }
+            is_valid = true;
+          }
+        }
+      }
+
+      if (!is_valid) {
+        AddError(
+            "expected an address-of expression of a variable identifier "
+            "expression or a function parameter",
+            arg_expr->source());
+        return false;
+      }
+    }
   }
-
-  function_calls_.emplace(call,
-                          FunctionCallInfo{callee_func, current_statement_});
-  SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
-
   return true;
 }
 
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index ad7e56d..0cbcc6f 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -287,6 +287,8 @@
   bool ValidateCallStatement(ast::CallStatement* stmt);
   bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
   bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
+  bool ValidateFunctionCall(const ast::CallExpression* call,
+                            const FunctionInfo* info);
   bool ValidateGlobalVariable(const VariableInfo* var);
   bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
                                      const sem::Type* storage_type);
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 92bd17c..c34fd7d 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -811,20 +811,20 @@
 }
 
 TEST_F(ResolverValidationTest, StructMemberDuplicateName) {
-  Structure("S",
-            {Member("a", ty.i32()), Member(Source{{12, 34}}, "a", ty.i32())});
+  Structure("S", {Member(Source{{12, 34}}, "a", ty.i32()),
+                  Member(Source{{56, 78}}, "a", ty.i32())});
   EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(
-      r()->error(),
-      "12:34 error: redefinition of 'a'\nnote: previous definition is here");
+  EXPECT_EQ(r()->error(),
+            "56:78 error: redefinition of 'a'\n12:34 note: previous definition "
+            "is here");
 }
 TEST_F(ResolverValidationTest, StructMemberDuplicateNameDifferentTypes) {
-  Structure("S", {Member("a", ty.bool_()),
+  Structure("S", {Member(Source{{12, 34}}, "a", ty.bool_()),
                   Member(Source{{12, 34}}, "a", ty.vec3<f32>())});
   EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(
-      r()->error(),
-      "12:34 error: redefinition of 'a'\nnote: previous definition is here");
+  EXPECT_EQ(r()->error(),
+            "12:34 error: redefinition of 'a'\n12:34 note: previous definition "
+            "is here");
 }
 TEST_F(ResolverValidationTest, StructMemberDuplicateNamePass) {
   Structure("S", {Member("a", ty.i32()), Member("b", ty.f32())});
diff --git a/src/transform/inline_pointer_lets_test.cc b/src/transform/inline_pointer_lets_test.cc
index 91818b8..7682f87 100644
--- a/src/transform/inline_pointer_lets_test.cc
+++ b/src/transform/inline_pointer_lets_test.cc
@@ -79,35 +79,6 @@
   EXPECT_EQ(expect, str(got));
 }
 
-TEST_F(InlinePointerLetsTest, Param) {
-  auto* src = R"(
-fn x(p : ptr<function, i32>) -> i32 {
-  return *p;
-}
-
-fn f() {
-  var v : i32;
-  let p : ptr<function, i32> = &v;
-  var r : i32 = x(p);
-}
-)";
-
-  auto* expect = R"(
-fn x(p : ptr<function, i32>) -> i32 {
-  return *(p);
-}
-
-fn f() {
-  var v : i32;
-  var r : i32 = x(&(v));
-}
-)";
-
-  auto got = Run<InlinePointerLets>(src);
-
-  EXPECT_EQ(expect, str(got));
-}
-
 TEST_F(InlinePointerLetsTest, SavedVars) {
   auto* src = R"(
 struct S {
diff --git a/src/writer/hlsl/generator_impl_sanitizer_test.cc b/src/writer/hlsl/generator_impl_sanitizer_test.cc
index 33c2cce..3fb7915 100644
--- a/src/writer/hlsl/generator_impl_sanitizer_test.cc
+++ b/src/writer/hlsl/generator_impl_sanitizer_test.cc
@@ -286,54 +286,6 @@
   EXPECT_EQ(expect, got);
 }
 
-TEST_F(HlslSanitizerTest, InlineParam) {
-  // fn x(p : ptr<function, i32>) -> i32 {
-  //   return *p;
-  // }
-  //
-  // [[stage(fragment)]]
-  // fn main() {
-  //   var v : i32;
-  //   let p : ptr<function, i32> = &v;
-  //   var r : i32 = x(p);
-  // }
-
-  Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
-       ty.i32(), {Return(Deref("p"))});
-
-  auto* v = Var("v", ty.i32());
-  auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
-                  AddressOf(v));
-  auto* r = Var("r", ty.i32(), ast::StorageClass::kNone, Call("x", p));
-
-  Func("main", ast::VariableList{}, ty.void_(),
-       {
-           Decl(v),
-           Decl(p),
-           Decl(r),
-       },
-       {
-           Stage(ast::PipelineStage::kFragment),
-       });
-
-  GeneratorImpl& gen = SanitizeAndBuild();
-
-  ASSERT_TRUE(gen.Generate()) << gen.error();
-
-  auto got = gen.result();
-  auto* expect = R"(int x(inout int p) {
-  return p;
-}
-
-void main() {
-  int v = 0;
-  int r = x(v);
-  return;
-}
-)";
-  EXPECT_EQ(expect, got);
-}
-
 }  // namespace
 }  // namespace hlsl
 }  // namespace writer