[tint][ir][val] Check function parameter types

Fixes: 369794226

Change-Id: I825abb6eba194c403eb06396d7951bfe11105f70
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/208854
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/tint/lang/core/ir/transform/preserve_padding_test.cc b/src/tint/lang/core/ir/transform/preserve_padding_test.cc
index 188029c..cb0afb0 100644
--- a/src/tint/lang/core/ir/transform/preserve_padding_test.cc
+++ b/src/tint/lang/core/ir/transform/preserve_padding_test.cc
@@ -263,7 +263,7 @@
 }
 
 TEST_F(IR_PreservePaddingTest, NoModify_ArrayWithoutPadding) {
-    auto* arr = ty.array<vec4<f32>>();
+    auto* arr = ty.array<vec4<f32>, 4>();
     auto* buffer = b.Var("buffer", ty.ptr(storage, arr));
     buffer->SetBindingPoint(0, 0);
     mod.root_block->Append(buffer);
@@ -278,10 +278,10 @@
 
     auto* src = R"(
 $B1: {  # root
-  %buffer:ptr<storage, array<vec4<f32>>, read_write> = var @binding_point(0, 0)
+  %buffer:ptr<storage, array<vec4<f32>, 4>, read_write> = var @binding_point(0, 0)
 }
 
-%foo = func(%value:array<vec4<f32>>):void {
+%foo = func(%value:array<vec4<f32>, 4>):void {
   $B2: {
     store %buffer, %value
     ret
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 2071427..8c6ad6d 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -129,11 +129,21 @@
     return false;
 }
 
-/// @returns true is @p attr contains both a location and builtin decoration
+/// @returns true if @p attr contains both a location and builtin decoration
 bool HasLocationAndBuiltin(const tint::core::IOAttributes& attr) {
     return attr.builtin.has_value() && attr.location.has_value();
 }
 
+/// @returns true if @p ty meets the basic function parameter rules (i.e. one of constructible,
+///          pointer, sampler or texture).
+///
+/// Note: Does not handle corner cases like if certain capabilities are
+/// enabled.
+bool IsValidFunctionParamType(const core::type::Type* ty) {
+    return ty->IsConstructible() || ty->Is<type::Pointer>() || ty->Is<type::Texture>() ||
+           ty->Is<type::Sampler>();
+}
+
 /// The core IR validator.
 class Validator {
   public:
@@ -1176,6 +1186,23 @@
             return;
         }
 
+        // References not allowed on function signatures even with Capability::kAllowRefTypes.
+        CheckType(
+            param->Type(), [&]() -> diag::Diagnostic& { return AddError(param); },
+            Capabilities{Capability::kAllowRefTypes});
+
+        if (!IsValidFunctionParamType(param->Type())) {
+            auto struct_ty = param->Type()->As<core::type::Struct>();
+            if (!capabilities_.Contains(Capability::kAllowPointersInStructures) || !struct_ty ||
+                struct_ty->Members().Any([](const core::type::StructMember* m) {
+                    return !IsValidFunctionParamType(m->Type());
+                })) {
+                AddError(param) << "function parameter type must be constructible, a pointer, a "
+                                   "texture, or a sampler";
+                return;
+            }
+        }
+
         if (HasLocationAndBuiltin(param->Attributes())) {
             AddError(param) << "a builtin and location cannot be both declared for a param";
             return;
@@ -1191,11 +1218,6 @@
             }
         }
 
-        // References not allowed on function signatures even with Capability::kAllowRefTypes.
-        CheckType(
-            param->Type(), [&]() -> diag::Diagnostic& { return AddError(param); },
-            Capabilities{Capability::kAllowRefTypes});
-
         scope_stack_.Add(param);
     }
 
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 415a848..33bbf50 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -383,6 +383,69 @@
 )");
 }
 
+TEST_F(IR_ValidatorTest, Function_ParameterWithConstructibleType) {
+    auto* f = b.Function("my_func", ty.void_());
+    auto* p = b.FunctionParam("my_param", ty.u32());
+    f->SetParams({p});
+    f->Block()->Append(b.Return(f));
+
+    auto res = ir::Validate(mod);
+    ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, Function_ParameterWithPointerType) {
+    auto* f = b.Function("my_func", ty.void_());
+    auto* p = b.FunctionParam("my_param", ty.ptr<function, i32>());
+    f->SetParams({p});
+    f->Block()->Append(b.Return(f));
+
+    auto res = ir::Validate(mod);
+    ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, Function_ParameterWithTextureType) {
+    auto* f = b.Function("my_func", ty.void_());
+    auto* p = b.FunctionParam("my_param", ty.external_texture());
+    f->SetParams({p});
+    f->Block()->Append(b.Return(f));
+
+    auto res = ir::Validate(mod);
+    ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, Function_ParameterWithSamplerType) {
+    auto* f = b.Function("my_func", ty.void_());
+    auto* p = b.FunctionParam("my_param", ty.sampler());
+    f->SetParams({p});
+    f->Block()->Append(b.Return(f));
+
+    auto res = ir::Validate(mod);
+    ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, Function_ParameterWithVoidType) {
+    auto* f = b.Function("my_func", ty.void_());
+    auto* p = b.FunctionParam("my_param", ty.void_());
+    f->SetParams({p});
+    f->Block()->Append(b.Return(f));
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(
+        res.Failure().reason.Str(),
+        R"(:1:17 error: function parameter type must be constructible, a pointer, a texture, or a sampler
+%my_func = func(%my_param:void):void {
+                ^^^^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%my_param:void):void {
+  $B1: {
+    ret
+  }
+}
+)");
+}
+
 TEST_F(IR_ValidatorTest, Function_Return_BothLocationAndBuiltin) {
     auto* f = b.Function("my_func", ty.f32());
 
@@ -7455,10 +7518,9 @@
         ty.Struct(mod.symbols.New("S"), {
                                             {mod.symbols.New("a"), ty.ptr<private_, i32>()},
                                         });
+    mod.root_block->Append(b.Var("my_struct", private_, str_ty));
 
     auto* fn = b.Function("F", ty.void_());
-    auto* param = b.FunctionParam("param", str_ty);
-    fn->SetParams({param});
     b.Append(fn->Block(), [&] { b.Return(fn); });
 
     auto res = ir::Validate(mod);
diff --git a/src/tint/lang/hlsl/ir/member_builtin_call_test.cc b/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
index e886929..b99af8f 100644
--- a/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
+++ b/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
@@ -71,12 +71,12 @@
 }
 
 TEST_F(IR_HlslMemberBuiltinCallTest, DoesNotMatchNonMemberFunction) {
-    auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
-
-    auto* t = b.FunctionParam("t", buf);
+    auto* buf_ty = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
+    auto* t = b.Var("t", buf_ty);
+    t->SetBindingPoint(0, 0);
+    mod.root_block->Append(t);
 
     auto* func = b.Function("foo", ty.u32());
-    func->SetParams({t});
     b.Append(func->Block(), [&] {
         auto* builtin =
             b.MemberCall<MemberBuiltinCall>(mod.Types().u32(), BuiltinFn::kAsint, t, 2_u);
@@ -87,18 +87,22 @@
     ASSERT_NE(res, Success);
     EXPECT_EQ(
         res.Failure().reason.Str(),
-        R"(:3:17 error: asint: no matching call to 'asint(hlsl.byte_address_buffer<read>, u32)'
+        R"(:7:17 error: asint: no matching call to 'asint(hlsl.byte_address_buffer<read>, u32)'
 
     %3:u32 = %t.asint 2u
                 ^^^^^
 
-:2:3 note: in block
-  $B1: {
+:6:3 note: in block
+  $B2: {
   ^^^
 
 note: # Disassembly
-%foo = func(%t:hlsl.byte_address_buffer<read>):u32 {
-  $B1: {
+$B1: {  # root
+  %t:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = func():u32 {
+  $B2: {
     %3:u32 = %t.asint 2u
     ret %3
   }
@@ -107,12 +111,12 @@
 }
 
 TEST_F(IR_HlslMemberBuiltinCallTest, DoesNotMatchIncorrectType) {
-    auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
-
-    auto* t = b.FunctionParam("t", buf);
+    auto* buf_ty = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
+    auto* t = b.Var("t", buf_ty);
+    t->SetBindingPoint(0, 0);
+    mod.root_block->Append(t);
 
     auto* func = b.Function("foo", ty.u32());
-    func->SetParams({t});
     b.Append(func->Block(), [&] {
         auto* builtin =
             b.MemberCall<MemberBuiltinCall>(mod.Types().u32(), BuiltinFn::kStore, t, 2_u, 2_u);
@@ -123,7 +127,7 @@
     ASSERT_NE(res, Success);
     EXPECT_EQ(
         res.Failure().reason.Str(),
-        R"(:3:17 error: Store: no matching call to 'Store(hlsl.byte_address_buffer<read>, u32, u32)'
+        R"(:7:17 error: Store: no matching call to 'Store(hlsl.byte_address_buffer<read>, u32, u32)'
 
 1 candidate function:
  • 'Store(byte_address_buffer<write' or 'read_write>  ✗ , offset: u32  ✓ , value: u32  ✓ )'
@@ -131,13 +135,17 @@
     %3:u32 = %t.Store 2u, 2u
                 ^^^^^
 
-:2:3 note: in block
-  $B1: {
+:6:3 note: in block
+  $B2: {
   ^^^
 
 note: # Disassembly
-%foo = func(%t:hlsl.byte_address_buffer<read>):u32 {
-  $B1: {
+$B1: {  # root
+  %t:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = func():u32 {
+  $B2: {
     %3:u32 = %t.Store 2u, 2u
     ret %3
   }
diff --git a/src/tint/utils/containers/vector.h b/src/tint/utils/containers/vector.h
index 23b5b6d..74f7129 100644
--- a/src/tint/utils/containers/vector.h
+++ b/src/tint/utils/containers/vector.h
@@ -1161,6 +1161,13 @@
         return hash;
     }
 
+    /// @returns true if the predicate function returns true for any of the elements of the vector
+    /// @param pred a function-like with the signature `bool(T)`
+    template <typename PREDICATE>
+    bool Any(PREDICATE&& pred) const {
+        return std::any_of(begin(), end(), std::forward<PREDICATE>(pred));
+    }
+
   private:
     /// Friend class
     template <typename, size_t>