validation: validate function call pointer parameter
Each argument of a function call of pointer type must be one of:
- An address-of expression of a variable identifier expression
- A function parameter
Also added source location to duplicate struct member name unittest
Bug: tint:983
Change-Id: Ic5ab010b2ed76207a1d8d3ef9f66140ea95f7e72
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58480
Auto-Submit: Sarah Mashayekhi <sarahmashay@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/resolver/call_validation_test.cc b/src/resolver/call_validation_test.cc
index d94b70d..0fc2212 100644
--- a/src/resolver/call_validation_test.cc
+++ b/src/resolver/call_validation_test.cc
@@ -128,6 +128,188 @@
"intentional wrap the function call in ignore()");
}
+TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) {
+ // fn foo(p: ptr<function, i32>) {}
+ // fn main() {
+ // var z: i32 = 1;
+ // foo(&z);
+ // }
+ auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+ Func("foo", {param}, ty.void_(), {});
+ Func("main", {}, ty.void_(),
+ ast::StatementList{
+ Decl(Var("z", ty.i32(), Expr(1))),
+ create<ast::CallStatement>(
+ Call("foo", AddressOf(Source{{12, 34}}, Expr("z")))),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
+ // fn foo(p: ptr<function, i32>) {}
+ // fn main() {
+ // let z: i32 = 1;
+ // foo(&z);
+ // }
+ auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+ Func("foo", {param}, ty.void_(), {});
+ Func("main", {}, ty.void_(),
+ ast::StatementList{
+ Decl(Const("z", ty.i32(), Expr(1))),
+ create<ast::CallStatement>(
+ Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
+ // struct S { m: i32; };
+ // fn foo(p: ptr<function, i32>) {}
+ // fn main() {
+ // var v: S;
+ // foo(&v.m);
+ // }
+ auto* S = Structure("S", {Member("m", ty.i32())});
+ auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+ Func("foo", {param}, ty.void_(), {});
+ Func("main", {}, ty.void_(),
+ ast::StatementList{
+ Decl(Var("v", ty.Of(S))),
+ create<ast::CallStatement>(Call(
+ "foo", AddressOf(Source{{12, 34}}, MemberAccessor("v", "m")))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected an address-of expression of a variable "
+ "identifier expression or a function parameter");
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
+ // struct S { m: i32; };
+ // fn foo(p: ptr<function, i32>) {}
+ // fn main() {
+ // let v: S = S();
+ // foo(&v.m);
+ // }
+ auto* S = Structure("S", {Member("m", ty.i32())});
+ auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
+ Func("foo", {param}, ty.void_(), {});
+ Func("main", {}, ty.void_(),
+ ast::StatementList{
+ Decl(Const("v", ty.Of(S), Construct(ty.Of(S)))),
+ create<ast::CallStatement>(Call(
+ "foo",
+ AddressOf(Expr(Source{{12, 34}}, MemberAccessor("v", "m"))))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
+ // fn foo(p: ptr<function, i32>) {}
+ // fn bar(p: ptr<function, i32>) {
+ // foo(p);
+ // }
+ Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(), {});
+ Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(),
+ ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) {
+ // fn foo(p: ptr<function, i32>) {}
+ // fn bar(p: ptr<function, i32>) {
+ // foo(p);
+ // }
+ // [[stage(fragment)]]
+ // fn main() {
+ // var v: i32;
+ // bar(&v);
+ // }
+ Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(), {});
+ Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(),
+ ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
+ Func("main", ast::VariableList{}, ty.void_(),
+ {
+ Decl(Var("v", ty.i32(), Expr(1))),
+ create<ast::CallStatement>(Call("foo", AddressOf(Expr("v")))),
+ },
+ {
+ Stage(ast::PipelineStage::kFragment),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverCallValidationTest, LetPointer) {
+ // fn x(p : ptr<function, i32>) -> i32 {}
+ // [[stage(fragment)]]
+ // fn main() {
+ // var v: i32;
+ // let p: ptr<function, i32> = &v;
+ // var c: i32 = x(p);
+ // }
+ Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
+ ty.void_(), {});
+ auto* v = Var("v", ty.i32());
+ auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
+ AddressOf(v));
+ auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
+ Call("x", Expr(Source{{12, 34}}, p)));
+ Func("main", ast::VariableList{}, ty.void_(),
+ {
+ Decl(v),
+ Decl(p),
+ Decl(c),
+ },
+ {
+ Stage(ast::PipelineStage::kFragment),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected an address-of expression of a variable "
+ "identifier expression or a function parameter");
+}
+
+TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
+ // let p: ptr<private, i32> = &v;
+ // fn foo(p : ptr<private, i32>) -> i32 {}
+ // var v: i32;
+ // [[stage(fragment)]]
+ // fn main() {
+ // var c: i32 = foo(p);
+ // }
+ Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kPrivate))},
+ ty.void_(), {});
+ auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
+ auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kPrivate),
+ AddressOf(v));
+ auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
+ Call("foo", Expr(Source{{12, 34}}, p)));
+ Func("main", ast::VariableList{}, ty.void_(),
+ {
+ Decl(p),
+ Decl(c),
+ },
+ {
+ Stage(ast::PipelineStage::kFragment),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected an address-of expression of a variable "
+ "identifier expression or a function parameter");
+}
+
} // namespace
} // namespace resolver
} // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 374f99b..e485c05 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2578,7 +2578,21 @@
}
}
- // Validate number of arguments match number of parameters
+ function_calls_.emplace(call,
+ FunctionCallInfo{callee_func, current_statement_});
+ SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
+
+ if (!ValidateFunctionCall(call, callee_func)) {
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
+ const FunctionInfo* callee_func) {
+ auto* ident = call->func();
+ auto name = builder_->Symbols().NameFor(ident->symbol());
+
if (call->params().size() != callee_func->parameters.size()) {
bool more = call->params().size() > callee_func->parameters.size();
AddError("too " + (more ? std::string("many") : std::string("few")) +
@@ -2589,7 +2603,6 @@
return false;
}
- // Validate arguments match parameter types
for (size_t i = 0; i < call->params().size(); ++i) {
const VariableInfo* param = callee_func->parameters[i];
const ast::Expression* arg_expr = call->params()[i];
@@ -2603,12 +2616,48 @@
arg_expr->source());
return false;
}
+
+ if (param->declaration->type()->Is<ast::Pointer>()) {
+ auto is_valid = false;
+ if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
+ VariableInfo* var;
+ if (!variable_stack_.get(ident_expr->symbol(), &var)) {
+ TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
+ return false;
+ }
+ if (var->kind == VariableKind::kParameter) {
+ is_valid = true;
+ }
+ } else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
+ if (unary->op() == ast::UnaryOp::kAddressOf) {
+ if (auto* ident_unary =
+ unary->expr()->As<ast::IdentifierExpression>()) {
+ VariableInfo* var;
+ if (!variable_stack_.get(ident_unary->symbol(), &var)) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "failed to resolve identifier";
+ return false;
+ }
+ if (var->declaration->is_const()) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "Resolver::FunctionCall() encountered an address-of "
+ "expression of a constant identifier expression";
+ return false;
+ }
+ is_valid = true;
+ }
+ }
+ }
+
+ if (!is_valid) {
+ AddError(
+ "expected an address-of expression of a variable identifier "
+ "expression or a function parameter",
+ arg_expr->source());
+ return false;
+ }
+ }
}
-
- function_calls_.emplace(call,
- FunctionCallInfo{callee_func, current_statement_});
- SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
-
return true;
}
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index ad7e56d..0cbcc6f 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -287,6 +287,8 @@
bool ValidateCallStatement(ast::CallStatement* stmt);
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
+ bool ValidateFunctionCall(const ast::CallExpression* call,
+ const FunctionInfo* info);
bool ValidateGlobalVariable(const VariableInfo* var);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type);
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 92bd17c..c34fd7d 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -811,20 +811,20 @@
}
TEST_F(ResolverValidationTest, StructMemberDuplicateName) {
- Structure("S",
- {Member("a", ty.i32()), Member(Source{{12, 34}}, "a", ty.i32())});
+ Structure("S", {Member(Source{{12, 34}}, "a", ty.i32()),
+ Member(Source{{56, 78}}, "a", ty.i32())});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(
- r()->error(),
- "12:34 error: redefinition of 'a'\nnote: previous definition is here");
+ EXPECT_EQ(r()->error(),
+ "56:78 error: redefinition of 'a'\n12:34 note: previous definition "
+ "is here");
}
TEST_F(ResolverValidationTest, StructMemberDuplicateNameDifferentTypes) {
- Structure("S", {Member("a", ty.bool_()),
+ Structure("S", {Member(Source{{12, 34}}, "a", ty.bool_()),
Member(Source{{12, 34}}, "a", ty.vec3<f32>())});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(
- r()->error(),
- "12:34 error: redefinition of 'a'\nnote: previous definition is here");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: redefinition of 'a'\n12:34 note: previous definition "
+ "is here");
}
TEST_F(ResolverValidationTest, StructMemberDuplicateNamePass) {
Structure("S", {Member("a", ty.i32()), Member("b", ty.f32())});
diff --git a/src/transform/inline_pointer_lets_test.cc b/src/transform/inline_pointer_lets_test.cc
index 91818b8..7682f87 100644
--- a/src/transform/inline_pointer_lets_test.cc
+++ b/src/transform/inline_pointer_lets_test.cc
@@ -79,35 +79,6 @@
EXPECT_EQ(expect, str(got));
}
-TEST_F(InlinePointerLetsTest, Param) {
- auto* src = R"(
-fn x(p : ptr<function, i32>) -> i32 {
- return *p;
-}
-
-fn f() {
- var v : i32;
- let p : ptr<function, i32> = &v;
- var r : i32 = x(p);
-}
-)";
-
- auto* expect = R"(
-fn x(p : ptr<function, i32>) -> i32 {
- return *(p);
-}
-
-fn f() {
- var v : i32;
- var r : i32 = x(&(v));
-}
-)";
-
- auto got = Run<InlinePointerLets>(src);
-
- EXPECT_EQ(expect, str(got));
-}
-
TEST_F(InlinePointerLetsTest, SavedVars) {
auto* src = R"(
struct S {
diff --git a/src/writer/hlsl/generator_impl_sanitizer_test.cc b/src/writer/hlsl/generator_impl_sanitizer_test.cc
index 33c2cce..3fb7915 100644
--- a/src/writer/hlsl/generator_impl_sanitizer_test.cc
+++ b/src/writer/hlsl/generator_impl_sanitizer_test.cc
@@ -286,54 +286,6 @@
EXPECT_EQ(expect, got);
}
-TEST_F(HlslSanitizerTest, InlineParam) {
- // fn x(p : ptr<function, i32>) -> i32 {
- // return *p;
- // }
- //
- // [[stage(fragment)]]
- // fn main() {
- // var v : i32;
- // let p : ptr<function, i32> = &v;
- // var r : i32 = x(p);
- // }
-
- Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
- ty.i32(), {Return(Deref("p"))});
-
- auto* v = Var("v", ty.i32());
- auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
- AddressOf(v));
- auto* r = Var("r", ty.i32(), ast::StorageClass::kNone, Call("x", p));
-
- Func("main", ast::VariableList{}, ty.void_(),
- {
- Decl(v),
- Decl(p),
- Decl(r),
- },
- {
- Stage(ast::PipelineStage::kFragment),
- });
-
- GeneratorImpl& gen = SanitizeAndBuild();
-
- ASSERT_TRUE(gen.Generate()) << gen.error();
-
- auto got = gen.result();
- auto* expect = R"(int x(inout int p) {
- return p;
-}
-
-void main() {
- int v = 0;
- int r = x(v);
- return;
-}
-)";
- EXPECT_EQ(expect, got);
-}
-
} // namespace
} // namespace hlsl
} // namespace writer