tint: Allow captured pointers as function args

The updated WGSL validation rule now requires that the memory view of
the argument matches its root identifier.

This allows for code like this:
   let p = &v;
   foo(p);

Fixed: tint:1754, tint:1734
Change-Id: I3239ec84e1c06398a6ce5bebb1e0b28986764bc6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/109221
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/resolver/call_validation_test.cc b/src/tint/resolver/call_validation_test.cc
index 6c64fad..82037af 100644
--- a/src/tint/resolver/call_validation_test.cc
+++ b/src/tint/resolver/call_validation_test.cc
@@ -114,25 +114,7 @@
     EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
-TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
-    // fn foo(p: ptr<function, i32>) {}
-    // fn main() {
-    //   let z: i32 = 1i;
-    //   foo(&z);
-    // }
-    auto* param = Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction));
-    Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
-    Func("main", utils::Empty, ty.void_(),
-         utils::Vector{
-             Decl(Let("z", ty.i32(), Expr(1_i))),
-             CallStmt(Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
-         });
-
-    EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
-}
-
-TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
+TEST_F(ResolverCallValidationTest, PointerArgument_NotWholeVar) {
     // struct S { m: i32; };
     // fn foo(p: ptr<function, i32>) {}
     // fn main() {
@@ -152,30 +134,8 @@
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: expected an address-of expression of a variable "
-              "identifier expression or a function parameter");
-}
-
-TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
-    // struct S { m: i32; };
-    // fn foo(p: ptr<function, i32>) {}
-    // fn main() {
-    //   let v: S = S();
-    //   foo(&v.m);
-    // }
-    auto* S = Structure("S", utils::Vector{
-                                 Member("m", ty.i32()),
-                             });
-    auto* param = Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction));
-    Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
-    Func("main", utils::Empty, ty.void_(),
-         utils::Vector{
-             Decl(Let("v", ty.Of(S), Construct(ty.Of(S)))),
-             CallStmt(Call("foo", AddressOf(MemberAccessor(Source{{12, 34}}, "v", "m")))),
-         });
-
-    EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+              "12:34 error: arguments of pointer type must not point to a subset of the "
+              "originating variable");
 }
 
 TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
@@ -235,65 +195,169 @@
     EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
-TEST_F(ResolverCallValidationTest, LetPointer) {
-    // fn x(p : ptr<function, i32>) -> i32 {}
-    // @fragment
-    // fn main() {
-    //   var v: i32;
-    //   let p: ptr<function, i32> = &v;
-    //   var c: i32 = x(p);
+TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam_NotWholeVar) {
+    // fn foo(p: ptr<function, i32>) {}
+    // fn bar(p: ptr<function, array<i32, 4>>) {
+    //   foo(&(*p)[0]);
     // }
-    Func("x",
+    Func("foo",
          utils::Vector{
              Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
          },
          ty.void_(), utils::Empty);
-    auto* v = Var("v", ty.i32());
-    auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf(v));
-    auto* c = Var("c", ty.i32(), Call("x", Expr(Source{{12, 34}}, p)));
+    Func("bar",
+         utils::Vector{
+             Param("p", ty.pointer(ty.array<i32, 4>(), ast::AddressSpace::kFunction)),
+         },
+         ty.void_(),
+         utils::Vector{
+             CallStmt(Call("foo", AddressOf(Source{{12, 34}}, IndexAccessor(Deref("p"), 0_a)))),
+         });
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(),
+              "12:34 error: arguments of pointer type must not point to a subset of the "
+              "originating variable");
+}
+
+TEST_F(ResolverCallValidationTest, LetPointer) {
+    // fn foo(p : ptr<function, i32>) {}
+    // @fragment
+    // fn main() {
+    //   var v: i32;
+    //   let p: ptr<function, i32> = &v;
+    //   foo(p);
+    // }
+    Func("foo",
+         utils::Vector{
+             Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
+         },
+         ty.void_(), utils::Empty);
     Func("main", utils::Empty, ty.void_(),
          utils::Vector{
-             Decl(v),
-             Decl(p),
-             Decl(c),
+             Decl(Var("v", ty.i32())),
+             Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf("v"))),
+             CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
          },
          utils::Vector{
              Stage(ast::PipelineStage::kFragment),
          });
-    EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(),
-              "12:34 error: expected an address-of expression of a variable "
-              "identifier expression or a function parameter");
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
 TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
-    // let p: ptr<private, i32> = &v;
-    // fn foo(p : ptr<private, i32>) -> i32 {}
-    // var v: i32;
+    // fn foo(p : ptr<private, i32>) {}
+    // var<private> v: i32;
     // @fragment
     // fn main() {
-    //   var c: i32 = foo(p);
+    //   let p: ptr<private, i32> = &v;
+    //   foo(p);
     // }
     Func("foo",
          utils::Vector{
              Param("p", ty.pointer<i32>(ast::AddressSpace::kPrivate)),
          },
          ty.void_(), utils::Empty);
-    auto* v = GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate);
-    auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf(v));
-    auto* c = Var("c", ty.i32(), Call("foo", Expr(Source{{12, 34}}, p)));
+    GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate);
     Func("main", utils::Empty, ty.void_(),
          utils::Vector{
-             Decl(p),
-             Decl(c),
+             Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf("v"))),
+             CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
+         },
+         utils::Vector{
+             Stage(ast::PipelineStage::kFragment),
+         });
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, LetPointer_NotWholeVar) {
+    // fn foo(p : ptr<function, i32>) {}
+    // @fragment
+    // fn main() {
+    //   var v: array<i32, 4>;
+    //   let p: ptr<function, i32> = &(v[0]);
+    //   x(p);
+    // }
+    Func("foo",
+         utils::Vector{
+             Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
+         },
+         ty.void_(), utils::Empty);
+    Func("main", utils::Empty, ty.void_(),
+         utils::Vector{
+             Decl(Var("v", ty.array<i32, 4>())),
+             Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction),
+                      AddressOf(IndexAccessor("v", 0_a)))),
+             CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
          },
          utils::Vector{
              Stage(ast::PipelineStage::kFragment),
          });
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: expected an address-of expression of a variable "
-              "identifier expression or a function parameter");
+              "12:34 error: arguments of pointer type must not point to a subset of the "
+              "originating variable");
+}
+
+TEST_F(ResolverCallValidationTest, ComplexPointerChain) {
+    // fn foo(p : ptr<function, array<i32, 4>>) {}
+    // @fragment
+    // fn main() {
+    //   var v: array<i32, 4>;
+    //   let p1 = &v;
+    //   let p2 = p1;
+    //   let p3 = &*p2;
+    //   foo(&*p);
+    // }
+    Func("foo",
+         utils::Vector{
+             Param("p", ty.pointer(ty.array<i32, 4>(), ast::AddressSpace::kFunction)),
+         },
+         ty.void_(), utils::Empty);
+    Func("main", utils::Empty, ty.void_(),
+         utils::Vector{
+             Decl(Var("v", ty.array<i32, 4>())),
+             Decl(Let("p1", AddressOf("v"))),
+             Decl(Let("p2", Expr("p1"))),
+             Decl(Let("p3", AddressOf(Deref("p2")))),
+             CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))),
+         },
+         utils::Vector{
+             Stage(ast::PipelineStage::kFragment),
+         });
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, ComplexPointerChain_NotWholeVar) {
+    // fn foo(p : ptr<function, i32>) {}
+    // @fragment
+    // fn main() {
+    //   var v: array<i32, 4>;
+    //   let p1 = &v;
+    //   let p2 = p1;
+    //   let p3 = &(*p2)[0];
+    //   foo(&*p);
+    // }
+    Func("foo",
+         utils::Vector{
+             Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
+         },
+         ty.void_(), utils::Empty);
+    Func("main", utils::Empty, ty.void_(),
+         utils::Vector{
+             Decl(Var("v", ty.array<i32, 4>())),
+             Decl(Let("p1", AddressOf("v"))),
+             Decl(Let("p2", Expr("p1"))),
+             Decl(Let("p3", AddressOf(IndexAccessor(Deref("p2"), 0_a)))),
+             CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))),
+         },
+         utils::Vector{
+             Stage(ast::PipelineStage::kFragment),
+         });
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(),
+              "12:34 error: arguments of pointer type must not point to a subset of the "
+              "originating variable");
 }
 
 TEST_F(ResolverCallValidationTest, CallVariable) {
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index f96af05..f285b5f 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1799,35 +1799,28 @@
         }
 
         if (param_type->Is<sem::Pointer>()) {
-            auto is_valid = false;
-            if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
-                auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_expr);
-                if (!var) {
-                    TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
-                    return false;
-                }
-                if (var->Is<sem::Parameter>()) {
-                    is_valid = true;
-                }
-            } else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
-                if (unary->op == ast::UnaryOp::kAddressOf) {
-                    if (auto* ident_unary = unary->expr->As<ast::IdentifierExpression>()) {
-                        auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_unary);
-                        if (!var) {
-                            TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
-                            return false;
-                        }
-                        is_valid = true;
-                    }
-                }
+            // https://gpuweb.github.io/gpuweb/wgsl/#function-restriction
+            // Each argument of pointer type to a user-defined function must have the same memory
+            // view as its root identifier.
+            // We can validate this by just comparing the store type of the argument with that of
+            // its root identifier, as these will match iff the memory view is the same.
+            auto* arg_store_type = arg_type->As<sem::Pointer>()->StoreType();
+            auto* root = call->Arguments()[i]->RootIdentifier();
+            auto* root_ptr_ty = root->Type()->As<sem::Pointer>();
+            auto* root_ref_ty = root->Type()->As<sem::Reference>();
+            TINT_ASSERT(Resolver, root_ptr_ty || root_ref_ty);
+            const sem::Type* root_store_type;
+            if (root_ptr_ty) {
+                root_store_type = root_ptr_ty->StoreType();
+            } else {
+                root_store_type = root_ref_ty->StoreType();
             }
-
-            if (!is_valid &&
+            if (root_store_type != arg_store_type &&
                 IsValidationEnabled(param->Declaration()->attributes,
                                     ast::DisabledValidation::kIgnoreInvalidPointerArgument)) {
                 AddError(
-                    "expected an address-of expression of a variable identifier expression or a "
-                    "function parameter",
+                    "arguments of pointer type must not point to a subset of the originating "
+                    "variable",
                     arg_expr->source);
                 return false;
             }