[ir] Allow access to extract pointers from structs
The validator previously required that only `access` instructions with
a pointer source operand can produce a pointer result. When extracting
a pointer member from a struct value this is not the case, so update
the validation logic to allow for this.
Change-Id: I6c98427746d1cdb98c4d9cf4c61a76dc6ccfbcd4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/188340
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 891fc09..1d69625 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -992,13 +992,20 @@
auto* want = a->Result(0)->Type();
auto* want_view = want->As<type::MemoryView>();
- bool ok = ty == want->UnwrapPtrOrRef() && (obj_view == nullptr) == (want_view == nullptr);
- if (ok && obj_view) {
- ok = obj_view->Is<type::Pointer>() == want_view->Is<type::Pointer>() &&
- obj_view->AddressSpace() == want_view->AddressSpace() &&
- obj_view->Access() == want_view->Access();
+ bool ok = true;
+ if (obj_view) {
+ // Pointer source always means pointer result.
+ ok = want_view && ty == want_view->StoreType();
+ if (ok) {
+ // Also check that the address space and access modes match.
+ ok = obj_view->Is<type::Pointer>() == want_view->Is<type::Pointer>() &&
+ obj_view->AddressSpace() == want_view->AddressSpace() &&
+ obj_view->Access() == want_view->Access();
+ }
+ } else {
+ // Otherwise, result types should exactly match.
+ ok = ty == want;
}
-
if (TINT_UNLIKELY(!ok)) {
AddError(a) << "result of access chain is type " << desc_of(in_kind, ty)
<< " but instruction type is " << style::Type(want->FriendlyName());
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index b14f6c3..bc6b744 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -1114,6 +1114,25 @@
ASSERT_EQ(res, Success);
}
+TEST_F(IR_ValidatorTest, Access_ExtractPointerFromStruct) {
+ auto* ptr = ty.ptr<private_, i32>();
+ Vector<type::Manager::StructMemberDesc, 1> members{
+ type::Manager::StructMemberDesc{mod.symbols.New("a"), ptr},
+ };
+ auto* str = ty.Struct(mod.symbols.New("MyStruct"), std::move(members));
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam("obj", str);
+ f->SetParams({obj});
+
+ b.Append(f->Block(), [&] {
+ b.Access(ptr, obj, 0_u);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success);
+}
+
TEST_F(IR_ValidatorTest, Block_TerminatorInMiddle) {
auto* f = b.Function("my_func", ty.void_());