[tint][ir] Add Capability::kAllowRefTypes

Validate that references are not used without this capability enabled.

Change-Id: Ie68cf2f8924d00f56deacc4f05222c6c347bfab6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/181361
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/transform/helper_test.h b/src/tint/lang/core/ir/transform/helper_test.h
index 6b1bdfa..84e0f7e 100644
--- a/src/tint/lang/core/ir/transform/helper_test.h
+++ b/src/tint/lang/core/ir/transform/helper_test.h
@@ -37,6 +37,7 @@
 #include "src/tint/lang/core/ir/builder.h"
 #include "src/tint/lang/core/ir/disassembler.h"
 #include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/utils/containers/enum_set.h"
 
 namespace tint::core::ir::transform {
 
@@ -57,7 +58,7 @@
         }
 
         // Validate the output IR.
-        EXPECT_EQ(ir::Validate(mod), Success);
+        EXPECT_EQ(ir::Validate(mod, capabilities), Success);
     }
 
     /// @returns the transformed module as a disassembled string
@@ -70,6 +71,8 @@
     ir::Builder b{mod};
     /// The type manager.
     core::type::Manager& ty{mod.Types()};
+    /// IR validation capabilities
+    Capabilities capabilities;
 };
 
 using TransformTest = TransformTestBase<testing::Test>;
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 24bfee7..0a61365 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -27,6 +27,7 @@
 
 #include "src/tint/lang/core/ir/validator.h"
 
+#include <cstdint>
 #include <memory>
 #include <string>
 #include <utility>
@@ -68,6 +69,7 @@
 #include "src/tint/lang/core/type/memory_view.h"
 #include "src/tint/lang/core/type/pointer.h"
 #include "src/tint/lang/core/type/reference.h"
+#include "src/tint/lang/core/type/type.h"
 #include "src/tint/lang/core/type/vector.h"
 #include "src/tint/lang/core/type/void.h"
 #include "src/tint/utils/containers/reverse.h"
@@ -88,13 +90,47 @@
 
 namespace {
 
+/// @returns true if the type @p type is of, or indirectly references a type of type `T`.
+template <typename T>
+bool HoldsType(const type::Type* type) {
+    if (!type) {
+        return false;
+    }
+    Vector<const type::Type*, 8> stack{type};
+    Hashset<const type::Type*, 8> seen{type};
+    while (!stack.IsEmpty()) {
+        auto* ty = stack.Pop();
+        if (ty->Is<T>()) {
+            return true;
+        }
+
+        if (auto* view = ty->As<type::MemoryView>(); view && seen.Add(view)) {
+            stack.Push(view);
+            continue;
+        }
+
+        auto type_count = ty->Elements();
+        if (type_count.type && seen.Add(type_count.type)) {
+            stack.Push(type_count.type);
+            continue;
+        }
+
+        for (uint32_t i = 0; i < type_count.count; i++) {
+            if (auto* subtype = ty->Element(i); subtype && seen.Add(subtype)) {
+                stack.Push(subtype);
+            }
+        }
+    }
+    return false;
+}
+
 /// The core IR validator.
 class Validator {
   public:
     /// Create a core validator
     /// @param mod the module to be validated
     /// @param capabilities the optional capabilities that are allowed
-    explicit Validator(const Module& mod, EnumSet<Capability> capabilities);
+    explicit Validator(const Module& mod, Capabilities capabilities);
 
     /// Destructor
     ~Validator();
@@ -281,7 +317,7 @@
 
   private:
     const Module& mod_;
-    EnumSet<Capability> capabilities_;
+    Capabilities capabilities_;
     std::shared_ptr<Source::File> disassembly_file;
     diag::List diagnostics_;
     Disassembler dis_{mod_};
@@ -293,7 +329,7 @@
     void DisassembleIfNeeded();
 };
 
-Validator::Validator(const Module& mod, EnumSet<Capability> capabilities)
+Validator::Validator(const Module& mod, Capabilities capabilities)
     : mod_(mod), capabilities_(capabilities) {}
 
 Validator::~Validator() = default;
@@ -451,6 +487,18 @@
 
 void Validator::CheckFunction(const Function* func) {
     CheckBlock(func->Block());
+
+    // References not allowed on function signatures even with Capability::kAllowRefTypes
+    for (auto* param : func->Params()) {
+        if (HoldsType<type::Reference>(param->Type())) {
+            // TODO(dsinclair): Parameters need a source mapping.
+            AddError(Source{}) << "references are not permitted as parameter types";
+        }
+    }
+    if (HoldsType<type::Reference>(func->ReturnType())) {
+        // TODO(dsinclair): Function need a source mapping.
+        AddError(Source{}) << "references are not permitted as return types";
+    }
 }
 
 void Validator::CheckBlock(const Block* blk) {
@@ -494,6 +542,12 @@
         } else if (res->Instruction() != inst) {
             AddResultError(inst, i) << "instruction of result is a different instruction";
         }
+
+        if (!capabilities_.Contains(Capability::kAllowRefTypes)) {
+            if (HoldsType<type::Reference>(res->Type())) {
+                AddResultError(inst, i) << "reference type is not permitted";
+            }
+        }
     }
 
     auto ops = inst->Operands();
@@ -512,6 +566,12 @@
         if (!op->HasUsage(inst, i)) {
             AddError(inst, i) << "operand missing usage";
         }
+
+        if (!capabilities_.Contains(Capability::kAllowRefTypes)) {
+            if (HoldsType<type::Reference>(op->Type())) {
+                AddError(inst, i) << "reference type is not permitted";
+            }
+        }
     }
 
     tint::Switch(
@@ -1029,14 +1089,14 @@
 
 }  // namespace
 
-Result<SuccessType> Validate(const Module& mod, EnumSet<Capability> capabilities) {
+Result<SuccessType> Validate(const Module& mod, Capabilities capabilities) {
     Validator v(mod, capabilities);
     return v.Run();
 }
 
 Result<SuccessType> ValidateAndDumpIfNeeded([[maybe_unused]] const Module& ir,
                                             [[maybe_unused]] const char* msg,
-                                            [[maybe_unused]] EnumSet<Capability> capabilities) {
+                                            [[maybe_unused]] Capabilities capabilities) {
 #if TINT_DUMP_IR_WHEN_VALIDATING
     std::cout << "=========================================================" << std::endl;
     std::cout << "== IR dump before " << msg << ":" << std::endl;
diff --git a/src/tint/lang/core/ir/validator.h b/src/tint/lang/core/ir/validator.h
index ba854d8..6e4d475 100644
--- a/src/tint/lang/core/ir/validator.h
+++ b/src/tint/lang/core/ir/validator.h
@@ -44,13 +44,18 @@
 enum class Capability {
     /// Allows access instructions to create pointers to vector elements.
     kAllowVectorElementPointer,
+    /// Allows ref types
+    kAllowRefTypes,
 };
 
+/// Capabilities is a set of Capability
+using Capabilities = EnumSet<Capability>;
+
 /// Validates that a given IR module is correctly formed
 /// @param mod the module to validate
 /// @param capabilities the optional capabilities that are allowed
 /// @returns success or failure
-Result<SuccessType> Validate(const Module& mod, EnumSet<Capability> capabilities = {});
+Result<SuccessType> Validate(const Module& mod, Capabilities capabilities = {});
 
 /// Validates the module @p ir and dumps its contents if required by the build configuration.
 /// @param ir the module to transform
@@ -59,7 +64,7 @@
 /// @returns success or failure
 Result<SuccessType> ValidateAndDumpIfNeeded(const Module& ir,
                                             const char* msg,
-                                            EnumSet<Capability> capabilities = {});
+                                            Capabilities capabilities = {});
 
 }  // namespace tint::core::ir
 
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 256cf8f..e049e4d 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -26,15 +26,22 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <string>
+#include <tuple>
 #include <utility>
 
 #include "gmock/gmock.h"
+
+#include "src/tint/lang/core/address_space.h"
 #include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/function_param.h"
 #include "src/tint/lang/core/ir/ir_helper_test.h"
 #include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/lang/core/type/array.h"
+#include "src/tint/lang/core/type/manager.h"
 #include "src/tint/lang/core/type/matrix.h"
+#include "src/tint/lang/core/type/memory_view.h"
 #include "src/tint/lang/core/type/pointer.h"
+#include "src/tint/lang/core/type/reference.h"
 #include "src/tint/lang/core/type/struct.h"
 #include "src/tint/utils/text/string.h"
 
@@ -770,7 +777,7 @@
         b.Return(f);
     });
 
-    auto res = ir::Validate(mod, EnumSet<Capability>{Capability::kAllowVectorElementPointer});
+    auto res = ir::Validate(mod, Capabilities{Capability::kAllowVectorElementPointer});
     ASSERT_EQ(res, Success);
 }
 
@@ -815,7 +822,7 @@
         b.Return(f);
     });
 
-    auto res = ir::Validate(mod, EnumSet<Capability>{Capability::kAllowVectorElementPointer});
+    auto res = ir::Validate(mod, Capabilities{Capability::kAllowVectorElementPointer});
     ASSERT_EQ(res, Success);
 }
 
@@ -3518,5 +3525,112 @@
 )");
 }
 
+template <typename T>
+static const type::Type* TypeBuilder(type::Manager& m) {
+    return m.Get<T>();
+}
+template <typename T>
+static const type::Type* RefTypeBuilder(type::Manager& m) {
+    return m.ref<AddressSpace::kFunction, T>();
+}
+using TypeBuilderFn = decltype(&TypeBuilder<i32>);
+
+using IR_ValidatorRefTypeTest = IRTestParamHelper<std::tuple</* holds_ref */ bool,
+                                                             /* refs_allowed */ bool,
+                                                             /* type_builder */ TypeBuilderFn>>;
+
+TEST_P(IR_ValidatorRefTypeTest, Var) {
+    bool holds_ref = std::get<0>(GetParam());
+    bool refs_allowed = std::get<1>(GetParam());
+    auto* type = std::get<2>(GetParam())(ty);
+
+    auto* fn = b.Function("my_func", ty.void_());
+    b.Append(fn->Block(), [&] {
+        if (auto* view = type->As<type::MemoryView>()) {
+            b.Var(view);
+        } else {
+            b.Var(ty.ptr<function>(type));
+        }
+
+        b.Return(fn);
+    });
+
+    Capabilities caps;
+    if (refs_allowed) {
+        caps.Add(Capability::kAllowRefTypes);
+    }
+    auto res = ir::Validate(mod, caps);
+    if (!holds_ref || refs_allowed) {
+        ASSERT_EQ(res, Success) << res.Failure();
+    } else {
+        ASSERT_NE(res, Success);
+        EXPECT_THAT(res.Failure().reason.Str(),
+                    testing::HasSubstr("3:5 error: var: reference type is not permitted"));
+    }
+}
+
+TEST_P(IR_ValidatorRefTypeTest, FnParam) {
+    bool holds_ref = std::get<0>(GetParam());
+    bool refs_allowed = std::get<1>(GetParam());
+    auto* type = std::get<2>(GetParam())(ty);
+
+    auto* fn = b.Function("my_func", ty.void_());
+    fn->SetParams(Vector{b.FunctionParam(type)});
+    b.Append(fn->Block(), [&] { b.Return(fn); });
+
+    Capabilities caps;
+    if (refs_allowed) {
+        caps.Add(Capability::kAllowRefTypes);
+    }
+    auto res = ir::Validate(mod, caps);
+    if (!holds_ref) {
+        ASSERT_EQ(res, Success) << res.Failure();
+    } else {
+        ASSERT_NE(res, Success);
+        EXPECT_THAT(res.Failure().reason.Str(),
+                    testing::HasSubstr("references are not permitted as parameter types"));
+    }
+}
+
+TEST_P(IR_ValidatorRefTypeTest, FnRet) {
+    bool holds_ref = std::get<0>(GetParam());
+    bool refs_allowed = std::get<1>(GetParam());
+    auto* type = std::get<2>(GetParam())(ty);
+
+    auto* fn = b.Function("my_func", type);
+    b.Append(fn->Block(), [&] { b.Unreachable(); });
+
+    Capabilities caps;
+    if (refs_allowed) {
+        caps.Add(Capability::kAllowRefTypes);
+    }
+    auto res = ir::Validate(mod, caps);
+    if (!holds_ref) {
+        ASSERT_EQ(res, Success) << res.Failure();
+    } else {
+        ASSERT_NE(res, Success);
+        EXPECT_THAT(res.Failure().reason.Str(),
+                    testing::HasSubstr("references are not permitted as return types"));
+    }
+}
+
+INSTANTIATE_TEST_SUITE_P(NonRefTypes,
+                         IR_ValidatorRefTypeTest,
+                         testing::Combine(/* holds_ref */ testing::Values(false),
+                                          /* refs_allowed */ testing::Values(false, true),
+                                          /* type_builder */
+                                          testing::Values(TypeBuilder<i32>,
+                                                          TypeBuilder<bool>,
+                                                          TypeBuilder<vec4<f32>>,
+                                                          TypeBuilder<array<f32, 3>>)));
+
+INSTANTIATE_TEST_SUITE_P(RefTypes,
+                         IR_ValidatorRefTypeTest,
+                         testing::Combine(/* holds_ref */ testing::Values(true),
+                                          /* refs_allowed */ testing::Values(false, true),
+                                          /* type_builder */
+                                          testing::Values(RefTypeBuilder<i32>,
+                                                          RefTypeBuilder<bool>,
+                                                          RefTypeBuilder<vec4<f32>>)));
 }  // namespace
 }  // namespace tint::core::ir
diff --git a/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc b/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc
index ad0d3f3..1dee037 100644
--- a/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc
+++ b/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc
@@ -163,7 +163,7 @@
 
 Result<SuccessType> VectorElementPointer(core::ir::Module& ir) {
     auto result = ValidateAndDumpIfNeeded(ir, "VectorElementPointer transform",
-                                          EnumSet<core::ir::Capability>{
+                                          core::ir::Capabilities{
                                               core::ir::Capability::kAllowVectorElementPointer,
                                           });
     if (result != Success) {
diff --git a/src/tint/lang/spirv/reader/parser/helper_test.h b/src/tint/lang/spirv/reader/parser/helper_test.h
index 6842aee..1aece7f 100644
--- a/src/tint/lang/spirv/reader/parser/helper_test.h
+++ b/src/tint/lang/spirv/reader/parser/helper_test.h
@@ -74,7 +74,7 @@
 
         // Validate the IR module against the capabilities supported by the SPIR-V dialect.
         auto validated =
-            core::ir::Validate(parsed.Get(), EnumSet<core::ir::Capability>{
+            core::ir::Validate(parsed.Get(), core::ir::Capabilities{
                                                  core::ir::Capability::kAllowVectorElementPointer,
                                              });
         if (validated != Success) {
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
index 08d7035..8bf72e7 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
@@ -105,7 +105,8 @@
     explicit State(const core::ir::Module& m) : mod(m) {}
 
     Program Run(const ProgramOptions& options) {
-        if (auto res = core::ir::Validate(mod); res != Success) {
+        core::ir::Capabilities caps{core::ir::Capability::kAllowRefTypes};
+        if (auto res = core::ir::Validate(mod, caps); res != Success) {
             // IR module failed validation.
             b.Diagnostics() = res.Failure().reason;
             return Program{resolver::Resolve(b)};
diff --git a/src/tint/lang/wgsl/writer/raise/ptr_to_ref_test.cc b/src/tint/lang/wgsl/writer/raise/ptr_to_ref_test.cc
index 814aefd..851f9bb 100644
--- a/src/tint/lang/wgsl/writer/raise/ptr_to_ref_test.cc
+++ b/src/tint/lang/wgsl/writer/raise/ptr_to_ref_test.cc
@@ -59,7 +59,8 @@
         EXPECT_EQ(result, Success);
 
         // Validate the output IR.
-        auto res = core::ir::Validate(mod);
+        core::ir::Capabilities caps{core::ir::Capability::kAllowRefTypes};
+        auto res = core::ir::Validate(mod, caps);
         EXPECT_EQ(res, Success);
     }
 
diff --git a/src/tint/lang/wgsl/writer/raise/raise_test.cc b/src/tint/lang/wgsl/writer/raise/raise_test.cc
index 013d978..0e5cae6 100644
--- a/src/tint/lang/wgsl/writer/raise/raise_test.cc
+++ b/src/tint/lang/wgsl/writer/raise/raise_test.cc
@@ -28,6 +28,7 @@
 #include <utility>
 
 #include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/lang/core/type/struct.h"
 #include "src/tint/lang/wgsl/writer/raise/raise.h"
 
@@ -37,7 +38,10 @@
 using namespace tint::core::fluent_types;     // NOLINT
 using namespace tint::core::number_suffixes;  // NOLINT
 
-using WgslWriter_RaiseTest = core::ir::transform::TransformTest;
+class WgslWriter_RaiseTest : public core::ir::transform::TransformTest {
+  public:
+    WgslWriter_RaiseTest() { capabilities.Add(core::ir::Capability::kAllowRefTypes); }
+};
 
 TEST_F(WgslWriter_RaiseTest, BuiltinConversion) {
     auto* f = b.Function("f", ty.void_());