[ir] Add vector-element-ptr validation capability
The IR validation APIs now take an optional list of allowed
capabilities which can be used to relax certain core IR validation
rules.
The `kVectorElementPointer` rule will be used by the SPIR-V dialect.
Bug: tint:1952
Change-Id: Id1317e9f6970f2d2be4b83650d7249f2e08fc379
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/170001
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 2aa16f8..65c952a 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -90,7 +90,8 @@
public:
/// Create a core validator
/// @param mod the module to be validated
- explicit Validator(const Module& mod);
+ /// @param capabilities the optional capabilities that are allowed
+ explicit Validator(const Module& mod, EnumSet<Capability> capabilities);
/// Destructor
~Validator();
@@ -286,6 +287,7 @@
private:
const Module& mod_;
+ EnumSet<Capability> capabilities_;
std::shared_ptr<Source::File> disassembly_file;
diag::List diagnostics_;
Disassembler dis_{mod_};
@@ -297,7 +299,8 @@
void DisassembleIfNeeded();
};
-Validator::Validator(const Module& mod) : mod_(mod) {}
+Validator::Validator(const Module& mod, EnumSet<Capability> capabilities)
+ : mod_(mod), capabilities_(capabilities) {}
Validator::~Validator() = default;
@@ -651,9 +654,11 @@
return;
}
- if (obj_ptr && el_ty->Is<core::type::Vector>()) {
- err("cannot obtain address of vector element");
- return;
+ if (!capabilities_.Contains(Capability::kAllowVectorElementPointer)) {
+ if (obj_ptr && el_ty->Is<core::type::Vector>()) {
+ err("cannot obtain address of vector element");
+ return;
+ }
}
if (auto* const_index = index->As<ir::Constant>()) {
@@ -1024,13 +1029,14 @@
} // namespace
-Result<SuccessType> Validate(const Module& mod) {
- Validator v(mod);
+Result<SuccessType> Validate(const Module& mod, EnumSet<Capability> capabilities) {
+ Validator v(mod, capabilities);
return v.Run();
}
Result<SuccessType> ValidateAndDumpIfNeeded([[maybe_unused]] const Module& ir,
- [[maybe_unused]] const char* msg) {
+ [[maybe_unused]] const char* msg,
+ [[maybe_unused]] EnumSet<Capability> capabilities) {
#if TINT_DUMP_IR_WHEN_VALIDATING
std::cout << "=========================================================" << std::endl;
std::cout << "== IR dump before " << msg << ":" << std::endl;
@@ -1039,7 +1045,7 @@
#endif
#ifndef NDEBUG
- auto result = Validate(ir);
+ auto result = Validate(ir, capabilities);
if (result != Success) {
return result.Failure();
}
diff --git a/src/tint/lang/core/ir/validator.h b/src/tint/lang/core/ir/validator.h
index d6b0843..ba854d8 100644
--- a/src/tint/lang/core/ir/validator.h
+++ b/src/tint/lang/core/ir/validator.h
@@ -30,6 +30,7 @@
#include <string>
+#include "src/tint/utils/containers/enum_set.h"
#include "src/tint/utils/result/result.h"
// Forward declarations
@@ -39,16 +40,26 @@
namespace tint::core::ir {
+/// Enumerator of optional IR capabilities.
+enum class Capability {
+ /// Allows access instructions to create pointers to vector elements.
+ kAllowVectorElementPointer,
+};
+
/// Validates that a given IR module is correctly formed
/// @param mod the module to validate
+/// @param capabilities the optional capabilities that are allowed
/// @returns success or failure
-Result<SuccessType> Validate(const Module& mod);
+Result<SuccessType> Validate(const Module& mod, EnumSet<Capability> capabilities = {});
/// Validates the module @p ir and dumps its contents if required by the build configuration.
/// @param ir the module to transform
/// @param msg the msg to accompany the output
+/// @param capabilities the optional capabilities that are allowed
/// @returns success or failure
-Result<SuccessType> ValidateAndDumpIfNeeded(const Module& ir, const char* msg);
+Result<SuccessType> ValidateAndDumpIfNeeded(const Module& ir,
+ const char* msg,
+ EnumSet<Capability> capabilities = {});
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 3c2733a..0c2bef9 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -760,6 +760,20 @@
)");
}
+TEST_F(IR_ValidatorTest, Access_IndexVectorPtr_WithCapability) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.ptr<private_, vec3<f32>>());
+ f->SetParams({obj});
+
+ b.Append(f->Block(), [&] {
+ b.Access(ty.ptr<private_, f32>(), obj, 1_u);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod, EnumSet<Capability>{Capability::kAllowVectorElementPointer});
+ ASSERT_EQ(res, Success);
+}
+
TEST_F(IR_ValidatorTest, Access_IndexVectorPtr_ViaMatrixPtr) {
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
@@ -791,6 +805,20 @@
)");
}
+TEST_F(IR_ValidatorTest, Access_IndexVectorPtr_ViaMatrixPtr_WithCapability) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.ptr<private_, mat3x2<f32>>());
+ f->SetParams({obj});
+
+ b.Append(f->Block(), [&] {
+ b.Access(ty.ptr<private_, f32>(), obj, 1_u, 1_u);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod, EnumSet<Capability>{Capability::kAllowVectorElementPointer});
+ ASSERT_EQ(res, Success);
+}
+
TEST_F(IR_ValidatorTest, Access_Incorrect_Ptr_AddressSpace) {
auto* f = b.Function("my_func", ty.void_());
auto* obj = b.FunctionParam(ty.ptr<storage, array<f32, 2>, read>());