[tint][ir][val] Check function parameter types
Fixes: 369794226
Change-Id: I825abb6eba194c403eb06396d7951bfe11105f70
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/208854
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/tint/lang/core/ir/transform/preserve_padding_test.cc b/src/tint/lang/core/ir/transform/preserve_padding_test.cc
index 188029c..cb0afb0 100644
--- a/src/tint/lang/core/ir/transform/preserve_padding_test.cc
+++ b/src/tint/lang/core/ir/transform/preserve_padding_test.cc
@@ -263,7 +263,7 @@
}
TEST_F(IR_PreservePaddingTest, NoModify_ArrayWithoutPadding) {
- auto* arr = ty.array<vec4<f32>>();
+ auto* arr = ty.array<vec4<f32>, 4>();
auto* buffer = b.Var("buffer", ty.ptr(storage, arr));
buffer->SetBindingPoint(0, 0);
mod.root_block->Append(buffer);
@@ -278,10 +278,10 @@
auto* src = R"(
$B1: { # root
- %buffer:ptr<storage, array<vec4<f32>>, read_write> = var @binding_point(0, 0)
+ %buffer:ptr<storage, array<vec4<f32>, 4>, read_write> = var @binding_point(0, 0)
}
-%foo = func(%value:array<vec4<f32>>):void {
+%foo = func(%value:array<vec4<f32>, 4>):void {
$B2: {
store %buffer, %value
ret
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 2071427..8c6ad6d 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -129,11 +129,21 @@
return false;
}
-/// @returns true is @p attr contains both a location and builtin decoration
+/// @returns true if @p attr contains both a location and builtin decoration
bool HasLocationAndBuiltin(const tint::core::IOAttributes& attr) {
return attr.builtin.has_value() && attr.location.has_value();
}
+/// @returns true if @p ty meets the basic function parameter rules (i.e. one of constructible,
+/// pointer, sampler or texture).
+///
+/// Note: Does not handle corner cases like if certain capabilities are
+/// enabled.
+bool IsValidFunctionParamType(const core::type::Type* ty) {
+ return ty->IsConstructible() || ty->Is<type::Pointer>() || ty->Is<type::Texture>() ||
+ ty->Is<type::Sampler>();
+}
+
/// The core IR validator.
class Validator {
public:
@@ -1176,6 +1186,23 @@
return;
}
+ // References not allowed on function signatures even with Capability::kAllowRefTypes.
+ CheckType(
+ param->Type(), [&]() -> diag::Diagnostic& { return AddError(param); },
+ Capabilities{Capability::kAllowRefTypes});
+
+ if (!IsValidFunctionParamType(param->Type())) {
+ auto struct_ty = param->Type()->As<core::type::Struct>();
+ if (!capabilities_.Contains(Capability::kAllowPointersInStructures) || !struct_ty ||
+ struct_ty->Members().Any([](const core::type::StructMember* m) {
+ return !IsValidFunctionParamType(m->Type());
+ })) {
+ AddError(param) << "function parameter type must be constructible, a pointer, a "
+ "texture, or a sampler";
+ return;
+ }
+ }
+
if (HasLocationAndBuiltin(param->Attributes())) {
AddError(param) << "a builtin and location cannot be both declared for a param";
return;
@@ -1191,11 +1218,6 @@
}
}
- // References not allowed on function signatures even with Capability::kAllowRefTypes.
- CheckType(
- param->Type(), [&]() -> diag::Diagnostic& { return AddError(param); },
- Capabilities{Capability::kAllowRefTypes});
-
scope_stack_.Add(param);
}
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 415a848..33bbf50 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -383,6 +383,69 @@
)");
}
+TEST_F(IR_ValidatorTest, Function_ParameterWithConstructibleType) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* p = b.FunctionParam("my_param", ty.u32());
+ f->SetParams({p});
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, Function_ParameterWithPointerType) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* p = b.FunctionParam("my_param", ty.ptr<function, i32>());
+ f->SetParams({p});
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, Function_ParameterWithTextureType) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* p = b.FunctionParam("my_param", ty.external_texture());
+ f->SetParams({p});
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, Function_ParameterWithSamplerType) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* p = b.FunctionParam("my_param", ty.sampler());
+ f->SetParams({p});
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, Function_ParameterWithVoidType) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* p = b.FunctionParam("my_param", ty.void_());
+ f->SetParams({p});
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(
+ res.Failure().reason.Str(),
+ R"(:1:17 error: function parameter type must be constructible, a pointer, a texture, or a sampler
+%my_func = func(%my_param:void):void {
+ ^^^^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%my_param:void):void {
+ $B1: {
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_ValidatorTest, Function_Return_BothLocationAndBuiltin) {
auto* f = b.Function("my_func", ty.f32());
@@ -7455,10 +7518,9 @@
ty.Struct(mod.symbols.New("S"), {
{mod.symbols.New("a"), ty.ptr<private_, i32>()},
});
+ mod.root_block->Append(b.Var("my_struct", private_, str_ty));
auto* fn = b.Function("F", ty.void_());
- auto* param = b.FunctionParam("param", str_ty);
- fn->SetParams({param});
b.Append(fn->Block(), [&] { b.Return(fn); });
auto res = ir::Validate(mod);
diff --git a/src/tint/lang/hlsl/ir/member_builtin_call_test.cc b/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
index e886929..b99af8f 100644
--- a/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
+++ b/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
@@ -71,12 +71,12 @@
}
TEST_F(IR_HlslMemberBuiltinCallTest, DoesNotMatchNonMemberFunction) {
- auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
-
- auto* t = b.FunctionParam("t", buf);
+ auto* buf_ty = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
+ auto* t = b.Var("t", buf_ty);
+ t->SetBindingPoint(0, 0);
+ mod.root_block->Append(t);
auto* func = b.Function("foo", ty.u32());
- func->SetParams({t});
b.Append(func->Block(), [&] {
auto* builtin =
b.MemberCall<MemberBuiltinCall>(mod.Types().u32(), BuiltinFn::kAsint, t, 2_u);
@@ -87,18 +87,22 @@
ASSERT_NE(res, Success);
EXPECT_EQ(
res.Failure().reason.Str(),
- R"(:3:17 error: asint: no matching call to 'asint(hlsl.byte_address_buffer<read>, u32)'
+ R"(:7:17 error: asint: no matching call to 'asint(hlsl.byte_address_buffer<read>, u32)'
%3:u32 = %t.asint 2u
^^^^^
-:2:3 note: in block
- $B1: {
+:6:3 note: in block
+ $B2: {
^^^
note: # Disassembly
-%foo = func(%t:hlsl.byte_address_buffer<read>):u32 {
- $B1: {
+$B1: { # root
+ %t:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = func():u32 {
+ $B2: {
%3:u32 = %t.asint 2u
ret %3
}
@@ -107,12 +111,12 @@
}
TEST_F(IR_HlslMemberBuiltinCallTest, DoesNotMatchIncorrectType) {
- auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
-
- auto* t = b.FunctionParam("t", buf);
+ auto* buf_ty = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
+ auto* t = b.Var("t", buf_ty);
+ t->SetBindingPoint(0, 0);
+ mod.root_block->Append(t);
auto* func = b.Function("foo", ty.u32());
- func->SetParams({t});
b.Append(func->Block(), [&] {
auto* builtin =
b.MemberCall<MemberBuiltinCall>(mod.Types().u32(), BuiltinFn::kStore, t, 2_u, 2_u);
@@ -123,7 +127,7 @@
ASSERT_NE(res, Success);
EXPECT_EQ(
res.Failure().reason.Str(),
- R"(:3:17 error: Store: no matching call to 'Store(hlsl.byte_address_buffer<read>, u32, u32)'
+ R"(:7:17 error: Store: no matching call to 'Store(hlsl.byte_address_buffer<read>, u32, u32)'
1 candidate function:
• 'Store(byte_address_buffer<write' or 'read_write> ✗ , offset: u32 ✓ , value: u32 ✓ )'
@@ -131,13 +135,17 @@
%3:u32 = %t.Store 2u, 2u
^^^^^
-:2:3 note: in block
- $B1: {
+:6:3 note: in block
+ $B2: {
^^^
note: # Disassembly
-%foo = func(%t:hlsl.byte_address_buffer<read>):u32 {
- $B1: {
+$B1: { # root
+ %t:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = func():u32 {
+ $B2: {
%3:u32 = %t.Store 2u, 2u
ret %3
}
diff --git a/src/tint/utils/containers/vector.h b/src/tint/utils/containers/vector.h
index 23b5b6d..74f7129 100644
--- a/src/tint/utils/containers/vector.h
+++ b/src/tint/utils/containers/vector.h
@@ -1161,6 +1161,13 @@
return hash;
}
+ /// @returns true if the predicate function returns true for any of the elements of the vector
+ /// @param pred a function-like with the signature `bool(T)`
+ template <typename PREDICATE>
+ bool Any(PREDICATE&& pred) const {
+ return std::any_of(begin(), end(), std::forward<PREDICATE>(pred));
+ }
+
private:
/// Friend class
template <typename, size_t>