[ir] Add vector-element-ptr validation capability

The IR validation APIs now take an optional list of allowed
capabilities which can be used to relax certain core IR validation
rules.

The `kVectorElementPointer` rule will be used by the SPIR-V dialect.

Bug: tint:1952
Change-Id: Id1317e9f6970f2d2be4b83650d7249f2e08fc379
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/170001
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 2aa16f8..65c952a 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -90,7 +90,8 @@
   public:
     /// Create a core validator
     /// @param mod the module to be validated
-    explicit Validator(const Module& mod);
+    /// @param capabilities the optional capabilities that are allowed
+    explicit Validator(const Module& mod, EnumSet<Capability> capabilities);
 
     /// Destructor
     ~Validator();
@@ -286,6 +287,7 @@
 
   private:
     const Module& mod_;
+    EnumSet<Capability> capabilities_;
     std::shared_ptr<Source::File> disassembly_file;
     diag::List diagnostics_;
     Disassembler dis_{mod_};
@@ -297,7 +299,8 @@
     void DisassembleIfNeeded();
 };
 
-Validator::Validator(const Module& mod) : mod_(mod) {}
+Validator::Validator(const Module& mod, EnumSet<Capability> capabilities)
+    : mod_(mod), capabilities_(capabilities) {}
 
 Validator::~Validator() = default;
 
@@ -651,9 +654,11 @@
             return;
         }
 
-        if (obj_ptr && el_ty->Is<core::type::Vector>()) {
-            err("cannot obtain address of vector element");
-            return;
+        if (!capabilities_.Contains(Capability::kAllowVectorElementPointer)) {
+            if (obj_ptr && el_ty->Is<core::type::Vector>()) {
+                err("cannot obtain address of vector element");
+                return;
+            }
         }
 
         if (auto* const_index = index->As<ir::Constant>()) {
@@ -1024,13 +1029,14 @@
 
 }  // namespace
 
-Result<SuccessType> Validate(const Module& mod) {
-    Validator v(mod);
+Result<SuccessType> Validate(const Module& mod, EnumSet<Capability> capabilities) {
+    Validator v(mod, capabilities);
     return v.Run();
 }
 
 Result<SuccessType> ValidateAndDumpIfNeeded([[maybe_unused]] const Module& ir,
-                                            [[maybe_unused]] const char* msg) {
+                                            [[maybe_unused]] const char* msg,
+                                            [[maybe_unused]] EnumSet<Capability> capabilities) {
 #if TINT_DUMP_IR_WHEN_VALIDATING
     std::cout << "=========================================================" << std::endl;
     std::cout << "== IR dump before " << msg << ":" << std::endl;
@@ -1039,7 +1045,7 @@
 #endif
 
 #ifndef NDEBUG
-    auto result = Validate(ir);
+    auto result = Validate(ir, capabilities);
     if (result != Success) {
         return result.Failure();
     }
diff --git a/src/tint/lang/core/ir/validator.h b/src/tint/lang/core/ir/validator.h
index d6b0843..ba854d8 100644
--- a/src/tint/lang/core/ir/validator.h
+++ b/src/tint/lang/core/ir/validator.h
@@ -30,6 +30,7 @@
 
 #include <string>
 
+#include "src/tint/utils/containers/enum_set.h"
 #include "src/tint/utils/result/result.h"
 
 // Forward declarations
@@ -39,16 +40,26 @@
 
 namespace tint::core::ir {
 
+/// Enumerator of optional IR capabilities.
+enum class Capability {
+    /// Allows access instructions to create pointers to vector elements.
+    kAllowVectorElementPointer,
+};
+
 /// Validates that a given IR module is correctly formed
 /// @param mod the module to validate
+/// @param capabilities the optional capabilities that are allowed
 /// @returns success or failure
-Result<SuccessType> Validate(const Module& mod);
+Result<SuccessType> Validate(const Module& mod, EnumSet<Capability> capabilities = {});
 
 /// Validates the module @p ir and dumps its contents if required by the build configuration.
 /// @param ir the module to transform
 /// @param msg the msg to accompany the output
+/// @param capabilities the optional capabilities that are allowed
 /// @returns success or failure
-Result<SuccessType> ValidateAndDumpIfNeeded(const Module& ir, const char* msg);
+Result<SuccessType> ValidateAndDumpIfNeeded(const Module& ir,
+                                            const char* msg,
+                                            EnumSet<Capability> capabilities = {});
 
 }  // namespace tint::core::ir
 
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 3c2733a..0c2bef9 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -760,6 +760,20 @@
 )");
 }
 
+TEST_F(IR_ValidatorTest, Access_IndexVectorPtr_WithCapability) {
+    auto* f = b.Function("my_func", ty.void_());
+    auto* obj = b.FunctionParam(ty.ptr<private_, vec3<f32>>());
+    f->SetParams({obj});
+
+    b.Append(f->Block(), [&] {
+        b.Access(ty.ptr<private_, f32>(), obj, 1_u);
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod, EnumSet<Capability>{Capability::kAllowVectorElementPointer});
+    ASSERT_EQ(res, Success);
+}
+
 TEST_F(IR_ValidatorTest, Access_IndexVectorPtr_ViaMatrixPtr) {
     auto* f = b.Function("my_func", ty.void_());
     auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
@@ -791,6 +805,20 @@
 )");
 }
 
+TEST_F(IR_ValidatorTest, Access_IndexVectorPtr_ViaMatrixPtr_WithCapability) {
+    auto* f = b.Function("my_func", ty.void_());
+    auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
+    f->SetParams({obj});
+
+    b.Append(f->Block(), [&] {
+        b.Access(ty.ptr<private_, f32>(), obj, 1_u, 1_u);
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod, EnumSet<Capability>{Capability::kAllowVectorElementPointer});
+    ASSERT_EQ(res, Success);
+}
+
 TEST_F(IR_ValidatorTest, Access_Incorrect_Ptr_AddressSpace) {
     auto* f = b.Function("my_func", ty.void_());
     auto* obj = b.FunctionParam(ty.ptr<storage, array<f32, 2>, read>());