tint: Allow captured pointers as function args
The updated WGSL validation rule now requires that the memory view of
the argument matches its root identifier.
This allows for code like this:
let p = &v;
foo(p);
Fixed: tint:1754, tint:1734
Change-Id: I3239ec84e1c06398a6ce5bebb1e0b28986764bc6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/109221
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/resolver/call_validation_test.cc b/src/tint/resolver/call_validation_test.cc
index 6c64fad..82037af 100644
--- a/src/tint/resolver/call_validation_test.cc
+++ b/src/tint/resolver/call_validation_test.cc
@@ -114,25 +114,7 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
-TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
- // fn foo(p: ptr<function, i32>) {}
- // fn main() {
- // let z: i32 = 1i;
- // foo(&z);
- // }
- auto* param = Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction));
- Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
- Func("main", utils::Empty, ty.void_(),
- utils::Vector{
- Decl(Let("z", ty.i32(), Expr(1_i))),
- CallStmt(Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
- });
-
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
-}
-
-TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
+TEST_F(ResolverCallValidationTest, PointerArgument_NotWholeVar) {
// struct S { m: i32; };
// fn foo(p: ptr<function, i32>) {}
// fn main() {
@@ -152,30 +134,8 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: expected an address-of expression of a variable "
- "identifier expression or a function parameter");
-}
-
-TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
- // struct S { m: i32; };
- // fn foo(p: ptr<function, i32>) {}
- // fn main() {
- // let v: S = S();
- // foo(&v.m);
- // }
- auto* S = Structure("S", utils::Vector{
- Member("m", ty.i32()),
- });
- auto* param = Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction));
- Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
- Func("main", utils::Empty, ty.void_(),
- utils::Vector{
- Decl(Let("v", ty.Of(S), Construct(ty.Of(S)))),
- CallStmt(Call("foo", AddressOf(MemberAccessor(Source{{12, 34}}, "v", "m")))),
- });
-
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+ "12:34 error: arguments of pointer type must not point to a subset of the "
+ "originating variable");
}
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
@@ -235,65 +195,169 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
-TEST_F(ResolverCallValidationTest, LetPointer) {
- // fn x(p : ptr<function, i32>) -> i32 {}
- // @fragment
- // fn main() {
- // var v: i32;
- // let p: ptr<function, i32> = &v;
- // var c: i32 = x(p);
+TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam_NotWholeVar) {
+ // fn foo(p: ptr<function, i32>) {}
+ // fn bar(p: ptr<function, array<i32, 4>>) {
+ // foo(&(*p)[0]);
// }
- Func("x",
+ Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
- auto* v = Var("v", ty.i32());
- auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf(v));
- auto* c = Var("c", ty.i32(), Call("x", Expr(Source{{12, 34}}, p)));
+ Func("bar",
+ utils::Vector{
+ Param("p", ty.pointer(ty.array<i32, 4>(), ast::AddressSpace::kFunction)),
+ },
+ ty.void_(),
+ utils::Vector{
+ CallStmt(Call("foo", AddressOf(Source{{12, 34}}, IndexAccessor(Deref("p"), 0_a)))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: arguments of pointer type must not point to a subset of the "
+ "originating variable");
+}
+
+TEST_F(ResolverCallValidationTest, LetPointer) {
+ // fn foo(p : ptr<function, i32>) {}
+ // @fragment
+ // fn main() {
+ // var v: i32;
+ // let p: ptr<function, i32> = &v;
+ // foo(p);
+ // }
+ Func("foo",
+ utils::Vector{
+ Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
+ },
+ ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
- Decl(v),
- Decl(p),
- Decl(c),
+ Decl(Var("v", ty.i32())),
+ Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf("v"))),
+ CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(),
- "12:34 error: expected an address-of expression of a variable "
- "identifier expression or a function parameter");
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
- // let p: ptr<private, i32> = &v;
- // fn foo(p : ptr<private, i32>) -> i32 {}
- // var v: i32;
+ // fn foo(p : ptr<private, i32>) {}
+ // var<private> v: i32;
// @fragment
// fn main() {
- // var c: i32 = foo(p);
+ // let p: ptr<private, i32> = &v;
+ // foo(p);
// }
Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kPrivate)),
},
ty.void_(), utils::Empty);
- auto* v = GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate);
- auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf(v));
- auto* c = Var("c", ty.i32(), Call("foo", Expr(Source{{12, 34}}, p)));
+ GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
- Decl(p),
- Decl(c),
+ Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf("v"))),
+ CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
+ },
+ utils::Vector{
+ Stage(ast::PipelineStage::kFragment),
+ });
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, LetPointer_NotWholeVar) {
+ // fn foo(p : ptr<function, i32>) {}
+ // @fragment
+ // fn main() {
+ // var v: array<i32, 4>;
+ // let p: ptr<function, i32> = &(v[0]);
+ // x(p);
+ // }
+ Func("foo",
+ utils::Vector{
+ Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
+ },
+ ty.void_(), utils::Empty);
+ Func("main", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Var("v", ty.array<i32, 4>())),
+ Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction),
+ AddressOf(IndexAccessor("v", 0_a)))),
+ CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: expected an address-of expression of a variable "
- "identifier expression or a function parameter");
+ "12:34 error: arguments of pointer type must not point to a subset of the "
+ "originating variable");
+}
+
+TEST_F(ResolverCallValidationTest, ComplexPointerChain) {
+ // fn foo(p : ptr<function, array<i32, 4>>) {}
+ // @fragment
+ // fn main() {
+ // var v: array<i32, 4>;
+ // let p1 = &v;
+ // let p2 = p1;
+ // let p3 = &*p2;
+ // foo(&*p);
+ // }
+ Func("foo",
+ utils::Vector{
+ Param("p", ty.pointer(ty.array<i32, 4>(), ast::AddressSpace::kFunction)),
+ },
+ ty.void_(), utils::Empty);
+ Func("main", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Var("v", ty.array<i32, 4>())),
+ Decl(Let("p1", AddressOf("v"))),
+ Decl(Let("p2", Expr("p1"))),
+ Decl(Let("p3", AddressOf(Deref("p2")))),
+ CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))),
+ },
+ utils::Vector{
+ Stage(ast::PipelineStage::kFragment),
+ });
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, ComplexPointerChain_NotWholeVar) {
+ // fn foo(p : ptr<function, i32>) {}
+ // @fragment
+ // fn main() {
+ // var v: array<i32, 4>;
+ // let p1 = &v;
+ // let p2 = p1;
+ // let p3 = &(*p2)[0];
+ // foo(&*p);
+ // }
+ Func("foo",
+ utils::Vector{
+ Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
+ },
+ ty.void_(), utils::Empty);
+ Func("main", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Var("v", ty.array<i32, 4>())),
+ Decl(Let("p1", AddressOf("v"))),
+ Decl(Let("p2", Expr("p1"))),
+ Decl(Let("p3", AddressOf(IndexAccessor(Deref("p2"), 0_a)))),
+ CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))),
+ },
+ utils::Vector{
+ Stage(ast::PipelineStage::kFragment),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: arguments of pointer type must not point to a subset of the "
+ "originating variable");
}
TEST_F(ResolverCallValidationTest, CallVariable) {
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index f96af05..f285b5f 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1799,35 +1799,28 @@
}
if (param_type->Is<sem::Pointer>()) {
- auto is_valid = false;
- if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
- auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_expr);
- if (!var) {
- TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
- return false;
- }
- if (var->Is<sem::Parameter>()) {
- is_valid = true;
- }
- } else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
- if (unary->op == ast::UnaryOp::kAddressOf) {
- if (auto* ident_unary = unary->expr->As<ast::IdentifierExpression>()) {
- auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_unary);
- if (!var) {
- TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
- return false;
- }
- is_valid = true;
- }
- }
+ // https://gpuweb.github.io/gpuweb/wgsl/#function-restriction
+ // Each argument of pointer type to a user-defined function must have the same memory
+ // view as its root identifier.
+ // We can validate this by just comparing the store type of the argument with that of
+ // its root identifier, as these will match iff the memory view is the same.
+ auto* arg_store_type = arg_type->As<sem::Pointer>()->StoreType();
+ auto* root = call->Arguments()[i]->RootIdentifier();
+ auto* root_ptr_ty = root->Type()->As<sem::Pointer>();
+ auto* root_ref_ty = root->Type()->As<sem::Reference>();
+ TINT_ASSERT(Resolver, root_ptr_ty || root_ref_ty);
+ const sem::Type* root_store_type;
+ if (root_ptr_ty) {
+ root_store_type = root_ptr_ty->StoreType();
+ } else {
+ root_store_type = root_ref_ty->StoreType();
}
-
- if (!is_valid &&
+ if (root_store_type != arg_store_type &&
IsValidationEnabled(param->Declaration()->attributes,
ast::DisabledValidation::kIgnoreInvalidPointerArgument)) {
AddError(
- "expected an address-of expression of a variable identifier expression or a "
- "function parameter",
+ "arguments of pointer type must not point to a subset of the originating "
+ "variable",
arg_expr->source);
return false;
}