[tint][ir] Add Capability::kAllowRefTypes
Validate that references are not used without this capability enabled.
Change-Id: Ie68cf2f8924d00f56deacc4f05222c6c347bfab6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/181361
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/transform/helper_test.h b/src/tint/lang/core/ir/transform/helper_test.h
index 6b1bdfa..84e0f7e 100644
--- a/src/tint/lang/core/ir/transform/helper_test.h
+++ b/src/tint/lang/core/ir/transform/helper_test.h
@@ -37,6 +37,7 @@
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/disassembler.h"
#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/utils/containers/enum_set.h"
namespace tint::core::ir::transform {
@@ -57,7 +58,7 @@
}
// Validate the output IR.
- EXPECT_EQ(ir::Validate(mod), Success);
+ EXPECT_EQ(ir::Validate(mod, capabilities), Success);
}
/// @returns the transformed module as a disassembled string
@@ -70,6 +71,8 @@
ir::Builder b{mod};
/// The type manager.
core::type::Manager& ty{mod.Types()};
+ /// IR validation capabilities
+ Capabilities capabilities;
};
using TransformTest = TransformTestBase<testing::Test>;
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 24bfee7..0a61365 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -27,6 +27,7 @@
#include "src/tint/lang/core/ir/validator.h"
+#include <cstdint>
#include <memory>
#include <string>
#include <utility>
@@ -68,6 +69,7 @@
#include "src/tint/lang/core/type/memory_view.h"
#include "src/tint/lang/core/type/pointer.h"
#include "src/tint/lang/core/type/reference.h"
+#include "src/tint/lang/core/type/type.h"
#include "src/tint/lang/core/type/vector.h"
#include "src/tint/lang/core/type/void.h"
#include "src/tint/utils/containers/reverse.h"
@@ -88,13 +90,47 @@
namespace {
+/// @returns true if the type @p type is of, or indirectly references a type of type `T`.
+template <typename T>
+bool HoldsType(const type::Type* type) {
+ if (!type) {
+ return false;
+ }
+ Vector<const type::Type*, 8> stack{type};
+ Hashset<const type::Type*, 8> seen{type};
+ while (!stack.IsEmpty()) {
+ auto* ty = stack.Pop();
+ if (ty->Is<T>()) {
+ return true;
+ }
+
+ if (auto* view = ty->As<type::MemoryView>(); view && seen.Add(view)) {
+ stack.Push(view);
+ continue;
+ }
+
+ auto type_count = ty->Elements();
+ if (type_count.type && seen.Add(type_count.type)) {
+ stack.Push(type_count.type);
+ continue;
+ }
+
+ for (uint32_t i = 0; i < type_count.count; i++) {
+ if (auto* subtype = ty->Element(i); subtype && seen.Add(subtype)) {
+ stack.Push(subtype);
+ }
+ }
+ }
+ return false;
+}
+
/// The core IR validator.
class Validator {
public:
/// Create a core validator
/// @param mod the module to be validated
/// @param capabilities the optional capabilities that are allowed
- explicit Validator(const Module& mod, EnumSet<Capability> capabilities);
+ explicit Validator(const Module& mod, Capabilities capabilities);
/// Destructor
~Validator();
@@ -281,7 +317,7 @@
private:
const Module& mod_;
- EnumSet<Capability> capabilities_;
+ Capabilities capabilities_;
std::shared_ptr<Source::File> disassembly_file;
diag::List diagnostics_;
Disassembler dis_{mod_};
@@ -293,7 +329,7 @@
void DisassembleIfNeeded();
};
-Validator::Validator(const Module& mod, EnumSet<Capability> capabilities)
+Validator::Validator(const Module& mod, Capabilities capabilities)
: mod_(mod), capabilities_(capabilities) {}
Validator::~Validator() = default;
@@ -451,6 +487,18 @@
void Validator::CheckFunction(const Function* func) {
CheckBlock(func->Block());
+
+ // References not allowed on function signatures even with Capability::kAllowRefTypes
+ for (auto* param : func->Params()) {
+ if (HoldsType<type::Reference>(param->Type())) {
+ // TODO(dsinclair): Parameters need a source mapping.
+ AddError(Source{}) << "references are not permitted as parameter types";
+ }
+ }
+ if (HoldsType<type::Reference>(func->ReturnType())) {
+ // TODO(dsinclair): Function need a source mapping.
+ AddError(Source{}) << "references are not permitted as return types";
+ }
}
void Validator::CheckBlock(const Block* blk) {
@@ -494,6 +542,12 @@
} else if (res->Instruction() != inst) {
AddResultError(inst, i) << "instruction of result is a different instruction";
}
+
+ if (!capabilities_.Contains(Capability::kAllowRefTypes)) {
+ if (HoldsType<type::Reference>(res->Type())) {
+ AddResultError(inst, i) << "reference type is not permitted";
+ }
+ }
}
auto ops = inst->Operands();
@@ -512,6 +566,12 @@
if (!op->HasUsage(inst, i)) {
AddError(inst, i) << "operand missing usage";
}
+
+ if (!capabilities_.Contains(Capability::kAllowRefTypes)) {
+ if (HoldsType<type::Reference>(op->Type())) {
+ AddError(inst, i) << "reference type is not permitted";
+ }
+ }
}
tint::Switch(
@@ -1029,14 +1089,14 @@
} // namespace
-Result<SuccessType> Validate(const Module& mod, EnumSet<Capability> capabilities) {
+Result<SuccessType> Validate(const Module& mod, Capabilities capabilities) {
Validator v(mod, capabilities);
return v.Run();
}
Result<SuccessType> ValidateAndDumpIfNeeded([[maybe_unused]] const Module& ir,
[[maybe_unused]] const char* msg,
- [[maybe_unused]] EnumSet<Capability> capabilities) {
+ [[maybe_unused]] Capabilities capabilities) {
#if TINT_DUMP_IR_WHEN_VALIDATING
std::cout << "=========================================================" << std::endl;
std::cout << "== IR dump before " << msg << ":" << std::endl;
diff --git a/src/tint/lang/core/ir/validator.h b/src/tint/lang/core/ir/validator.h
index ba854d8..6e4d475 100644
--- a/src/tint/lang/core/ir/validator.h
+++ b/src/tint/lang/core/ir/validator.h
@@ -44,13 +44,18 @@
enum class Capability {
/// Allows access instructions to create pointers to vector elements.
kAllowVectorElementPointer,
+ /// Allows ref types
+ kAllowRefTypes,
};
+/// Capabilities is a set of Capability
+using Capabilities = EnumSet<Capability>;
+
/// Validates that a given IR module is correctly formed
/// @param mod the module to validate
/// @param capabilities the optional capabilities that are allowed
/// @returns success or failure
-Result<SuccessType> Validate(const Module& mod, EnumSet<Capability> capabilities = {});
+Result<SuccessType> Validate(const Module& mod, Capabilities capabilities = {});
/// Validates the module @p ir and dumps its contents if required by the build configuration.
/// @param ir the module to transform
@@ -59,7 +64,7 @@
/// @returns success or failure
Result<SuccessType> ValidateAndDumpIfNeeded(const Module& ir,
const char* msg,
- EnumSet<Capability> capabilities = {});
+ Capabilities capabilities = {});
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 256cf8f..e049e4d 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -26,15 +26,22 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <string>
+#include <tuple>
#include <utility>
#include "gmock/gmock.h"
+
+#include "src/tint/lang/core/address_space.h"
#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/function_param.h"
#include "src/tint/lang/core/ir/ir_helper_test.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/array.h"
+#include "src/tint/lang/core/type/manager.h"
#include "src/tint/lang/core/type/matrix.h"
+#include "src/tint/lang/core/type/memory_view.h"
#include "src/tint/lang/core/type/pointer.h"
+#include "src/tint/lang/core/type/reference.h"
#include "src/tint/lang/core/type/struct.h"
#include "src/tint/utils/text/string.h"
@@ -770,7 +777,7 @@
b.Return(f);
});
- auto res = ir::Validate(mod, EnumSet<Capability>{Capability::kAllowVectorElementPointer});
+ auto res = ir::Validate(mod, Capabilities{Capability::kAllowVectorElementPointer});
ASSERT_EQ(res, Success);
}
@@ -815,7 +822,7 @@
b.Return(f);
});
- auto res = ir::Validate(mod, EnumSet<Capability>{Capability::kAllowVectorElementPointer});
+ auto res = ir::Validate(mod, Capabilities{Capability::kAllowVectorElementPointer});
ASSERT_EQ(res, Success);
}
@@ -3518,5 +3525,112 @@
)");
}
+template <typename T>
+static const type::Type* TypeBuilder(type::Manager& m) {
+ return m.Get<T>();
+}
+template <typename T>
+static const type::Type* RefTypeBuilder(type::Manager& m) {
+ return m.ref<AddressSpace::kFunction, T>();
+}
+using TypeBuilderFn = decltype(&TypeBuilder<i32>);
+
+using IR_ValidatorRefTypeTest = IRTestParamHelper<std::tuple</* holds_ref */ bool,
+ /* refs_allowed */ bool,
+ /* type_builder */ TypeBuilderFn>>;
+
+TEST_P(IR_ValidatorRefTypeTest, Var) {
+ bool holds_ref = std::get<0>(GetParam());
+ bool refs_allowed = std::get<1>(GetParam());
+ auto* type = std::get<2>(GetParam())(ty);
+
+ auto* fn = b.Function("my_func", ty.void_());
+ b.Append(fn->Block(), [&] {
+ if (auto* view = type->As<type::MemoryView>()) {
+ b.Var(view);
+ } else {
+ b.Var(ty.ptr<function>(type));
+ }
+
+ b.Return(fn);
+ });
+
+ Capabilities caps;
+ if (refs_allowed) {
+ caps.Add(Capability::kAllowRefTypes);
+ }
+ auto res = ir::Validate(mod, caps);
+ if (!holds_ref || refs_allowed) {
+ ASSERT_EQ(res, Success) << res.Failure();
+ } else {
+ ASSERT_NE(res, Success);
+ EXPECT_THAT(res.Failure().reason.Str(),
+ testing::HasSubstr("3:5 error: var: reference type is not permitted"));
+ }
+}
+
+TEST_P(IR_ValidatorRefTypeTest, FnParam) {
+ bool holds_ref = std::get<0>(GetParam());
+ bool refs_allowed = std::get<1>(GetParam());
+ auto* type = std::get<2>(GetParam())(ty);
+
+ auto* fn = b.Function("my_func", ty.void_());
+ fn->SetParams(Vector{b.FunctionParam(type)});
+ b.Append(fn->Block(), [&] { b.Return(fn); });
+
+ Capabilities caps;
+ if (refs_allowed) {
+ caps.Add(Capability::kAllowRefTypes);
+ }
+ auto res = ir::Validate(mod, caps);
+ if (!holds_ref) {
+ ASSERT_EQ(res, Success) << res.Failure();
+ } else {
+ ASSERT_NE(res, Success);
+ EXPECT_THAT(res.Failure().reason.Str(),
+ testing::HasSubstr("references are not permitted as parameter types"));
+ }
+}
+
+TEST_P(IR_ValidatorRefTypeTest, FnRet) {
+ bool holds_ref = std::get<0>(GetParam());
+ bool refs_allowed = std::get<1>(GetParam());
+ auto* type = std::get<2>(GetParam())(ty);
+
+ auto* fn = b.Function("my_func", type);
+ b.Append(fn->Block(), [&] { b.Unreachable(); });
+
+ Capabilities caps;
+ if (refs_allowed) {
+ caps.Add(Capability::kAllowRefTypes);
+ }
+ auto res = ir::Validate(mod, caps);
+ if (!holds_ref) {
+ ASSERT_EQ(res, Success) << res.Failure();
+ } else {
+ ASSERT_NE(res, Success);
+ EXPECT_THAT(res.Failure().reason.Str(),
+ testing::HasSubstr("references are not permitted as return types"));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(NonRefTypes,
+ IR_ValidatorRefTypeTest,
+ testing::Combine(/* holds_ref */ testing::Values(false),
+ /* refs_allowed */ testing::Values(false, true),
+ /* type_builder */
+ testing::Values(TypeBuilder<i32>,
+ TypeBuilder<bool>,
+ TypeBuilder<vec4<f32>>,
+ TypeBuilder<array<f32, 3>>)));
+
+INSTANTIATE_TEST_SUITE_P(RefTypes,
+ IR_ValidatorRefTypeTest,
+ testing::Combine(/* holds_ref */ testing::Values(true),
+ /* refs_allowed */ testing::Values(false, true),
+ /* type_builder */
+ testing::Values(RefTypeBuilder<i32>,
+ RefTypeBuilder<bool>,
+ RefTypeBuilder<vec4<f32>>)));
} // namespace
} // namespace tint::core::ir
diff --git a/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc b/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc
index ad0d3f3..1dee037 100644
--- a/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc
+++ b/src/tint/lang/spirv/reader/lower/vector_element_pointer.cc
@@ -163,7 +163,7 @@
Result<SuccessType> VectorElementPointer(core::ir::Module& ir) {
auto result = ValidateAndDumpIfNeeded(ir, "VectorElementPointer transform",
- EnumSet<core::ir::Capability>{
+ core::ir::Capabilities{
core::ir::Capability::kAllowVectorElementPointer,
});
if (result != Success) {
diff --git a/src/tint/lang/spirv/reader/parser/helper_test.h b/src/tint/lang/spirv/reader/parser/helper_test.h
index 6842aee..1aece7f 100644
--- a/src/tint/lang/spirv/reader/parser/helper_test.h
+++ b/src/tint/lang/spirv/reader/parser/helper_test.h
@@ -74,7 +74,7 @@
// Validate the IR module against the capabilities supported by the SPIR-V dialect.
auto validated =
- core::ir::Validate(parsed.Get(), EnumSet<core::ir::Capability>{
+ core::ir::Validate(parsed.Get(), core::ir::Capabilities{
core::ir::Capability::kAllowVectorElementPointer,
});
if (validated != Success) {
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
index 08d7035..8bf72e7 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
@@ -105,7 +105,8 @@
explicit State(const core::ir::Module& m) : mod(m) {}
Program Run(const ProgramOptions& options) {
- if (auto res = core::ir::Validate(mod); res != Success) {
+ core::ir::Capabilities caps{core::ir::Capability::kAllowRefTypes};
+ if (auto res = core::ir::Validate(mod, caps); res != Success) {
// IR module failed validation.
b.Diagnostics() = res.Failure().reason;
return Program{resolver::Resolve(b)};
diff --git a/src/tint/lang/wgsl/writer/raise/ptr_to_ref_test.cc b/src/tint/lang/wgsl/writer/raise/ptr_to_ref_test.cc
index 814aefd..851f9bb 100644
--- a/src/tint/lang/wgsl/writer/raise/ptr_to_ref_test.cc
+++ b/src/tint/lang/wgsl/writer/raise/ptr_to_ref_test.cc
@@ -59,7 +59,8 @@
EXPECT_EQ(result, Success);
// Validate the output IR.
- auto res = core::ir::Validate(mod);
+ core::ir::Capabilities caps{core::ir::Capability::kAllowRefTypes};
+ auto res = core::ir::Validate(mod, caps);
EXPECT_EQ(res, Success);
}
diff --git a/src/tint/lang/wgsl/writer/raise/raise_test.cc b/src/tint/lang/wgsl/writer/raise/raise_test.cc
index 013d978..0e5cae6 100644
--- a/src/tint/lang/wgsl/writer/raise/raise_test.cc
+++ b/src/tint/lang/wgsl/writer/raise/raise_test.cc
@@ -28,6 +28,7 @@
#include <utility>
#include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/struct.h"
#include "src/tint/lang/wgsl/writer/raise/raise.h"
@@ -37,7 +38,10 @@
using namespace tint::core::fluent_types; // NOLINT
using namespace tint::core::number_suffixes; // NOLINT
-using WgslWriter_RaiseTest = core::ir::transform::TransformTest;
+class WgslWriter_RaiseTest : public core::ir::transform::TransformTest {
+ public:
+ WgslWriter_RaiseTest() { capabilities.Add(core::ir::Capability::kAllowRefTypes); }
+};
TEST_F(WgslWriter_RaiseTest, BuiltinConversion) {
auto* f = b.Function("f", ty.void_());