[hlsl] Handle access chains of unnamed access chains
This CL adds support for access chains which use other access chains as
their originating variable.
Bug: 349867642
Change-Id: I617a756d4372696de12242c55f049258f0cdf801
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196316
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 8410e2a..8c9b75c 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -628,5 +628,41 @@
)");
}
+TEST_F(HlslWriterTest, AccessChainFromUnnamedAccessChain) {
+ auto* Inner =
+ ty.Struct(mod.symbols.New("Inner"),
+ {
+ {mod.symbols.New("c"), ty.f32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("d"), ty.u32(), core::type::StructMemberAttributes{}},
+ });
+ auto* sb = ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), Inner, core::type::StructMemberAttributes{}},
+ });
+
+ auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+ var->SetBindingPoint(0, 0);
+ b.ir.root_block->Append(var);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* x = b.Access(ty.ptr(storage, sb, core::Access::kReadWrite), var);
+ auto* y = b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), x->Result(0), 1_u);
+ b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.u32(), core::Access::kReadWrite),
+ y->Result(0), 1_u)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+RWByteAddressBuffer v : register(u0);
+void foo() {
+ uint b = v.Load(8u);
+}
+
+)");
+}
+
} // namespace
} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
index 8476ff3..089baa8 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
@@ -106,9 +106,9 @@
inst,
[&](core::ir::LoadVectorElement* l) { LoadVectorElement(l, var, var_ty); },
[&](core::ir::StoreVectorElement* s) { StoreVectorElement(s, var, var_ty); },
- [&](core::ir::Store* s) { Store(s); }, //
- [&](core::ir::Load* l) { Load(l, var); }, //
- [&](core::ir::Access* a) { Access(a, var, var_ty); }, //
+ [&](core::ir::Store* s) { Store(s); }, //
+ [&](core::ir::Load* l) { Load(l, var); }, //
+ [&](core::ir::Access* a) { Access(a, var, a->Object()->Type(), 0u); }, //
TINT_ICE_ON_NO_MATCH);
}
@@ -309,14 +309,17 @@
});
}
- void Access(core::ir::Access* a, core::ir::Var* var, const core::type::Pointer*) {
- const core::type::Type* obj = a->Object()->Type();
- auto* view = obj->As<core::type::MemoryView>();
- TINT_ASSERT(view);
+ void Access(core::ir::Access* a,
+ core::ir::Var* var,
+ const core::type::Type* obj,
+ uint32_t byte_offset) {
+ // Note, because we recurse through the `access` helper, the object passed in isn't
+ // necessarily the originating `var` object, but maybe a partially resolved access chain
+ // object.
+ if (auto* view = obj->As<core::type::MemoryView>()) {
+ obj = view->StoreType();
+ }
- obj = view->StoreType();
-
- uint32_t byte_offset = 0;
for (auto* idx_value : a->Indices()) {
auto* cnst = idx_value->As<core::ir::Constant>();
@@ -362,8 +365,11 @@
[&](core::ir::Let*) {
// TODO(dsinclair): handle let
},
- [&](core::ir::Access*) {
- // TODO(dsinclair): Handle access
+ [&](core::ir::Access* sub_access) {
+ // Treat an access chain of the access chain as a continuation of the outer
+ // chain. Pass through the object we stopped at and the current byte_offset and
+ // then restart the access chain replacement for the new access chain.
+ Access(sub_access, var, obj, byte_offset);
},
[&](core::ir::LoadVectorElement* lve) {
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
index da0b618..db557f4 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
@@ -177,11 +177,12 @@
EXPECT_EQ(expect, str());
}
-TEST_F(HlslWriterDecomposeMemoryAccessTest, DISABLED_AccessChainFromUnnamedAccessChain) {
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessChainFromUnnamedAccessChain) {
auto* Inner =
ty.Struct(mod.symbols.New("Inner"),
{
{mod.symbols.New("c"), ty.f32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("d"), ty.u32(), core::type::StructMemberAttributes{}},
});
auto* sb = ty.Struct(mod.symbols.New("SB"),
{
@@ -190,20 +191,22 @@
});
auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+ var->SetBindingPoint(0, 0);
b.ir.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
auto* x = b.Access(ty.ptr(storage, sb, core::Access::kReadWrite), var);
auto* y = b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), x->Result(0), 1_u);
- b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite),
- y->Result(0), 0_u)));
+ b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.u32(), core::Access::kReadWrite),
+ y->Result(0), 1_u)));
b.Return(func);
});
auto* src = R"(
Inner = struct @align(4) {
c:f32 @offset(0)
+ d:u32 @offset(4)
}
SB = struct @align(4) {
@@ -212,16 +215,16 @@
}
$B1: { # root
- %v:ptr<storage, SB, read_write> = var
+ %v:ptr<storage, SB, read_write> = var @binding_point(0, 0)
}
%foo = @fragment func():void {
$B2: {
%3:ptr<storage, SB, read_write> = access %v
%4:ptr<storage, Inner, read_write> = access %3, 1u
- %5:ptr<storage, f32, read_write> = access %4, 0u
- %6:f32 = load %5
- %b:f32 = let %6
+ %5:ptr<storage, u32, read_write> = access %4, 1u
+ %6:u32 = load %5
+ %b:u32 = let %6
ret
}
}
@@ -229,20 +232,25 @@
ASSERT_EQ(src, str());
auto* expect = R"(
-SB = struct @align(16) {
+Inner = struct @align(4) {
+ c:f32 @offset(0)
+ d:u32 @offset(4)
+}
+
+SB = struct @align(4) {
a:i32 @offset(0)
- b:vec3<f32> @offset(16)
+ b:Inner @offset(4)
}
$B1: { # root
- %v:hlsl.byte_address_buffer<read_write> = var
+ %v:hlsl.byte_address_buffer<read_write> = var @binding_point(0, 0)
}
%foo = @fragment func():void {
$B2: {
- %3:vec3<u32> = %v.Load3 16u
- %a:vec3<f32> = bitcast %3
- %b:f32 = %a 1u
+ %3:u32 = %v.Load 8u
+ %4:u32 = bitcast %3
+ %b:u32 = let %4
ret
}
}
@@ -265,13 +273,16 @@
});
auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+ var->SetBindingPoint(0, 0);
b.ir.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
- auto* a = b.Let("a", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), var, 1_u));
- b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite),
- a->Result(0), 0_u)));
+ auto* x = b.Let("x", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), var));
+ auto* y = b.Let(
+ "y", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), x->Result(0), 1_u));
+ b.Let("z", b.Load(b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite),
+ y->Result(0), 0_u)));
b.Return(func);
});
@@ -286,7 +297,7 @@
}
$B1: { # root
- %v:ptr<storage, SB, read_write> = var
+ %v:ptr<storage, SB, read_write> = var @binding_point(0, 0)
}
%foo = @fragment func():void {
@@ -309,7 +320,7 @@
}
$B1: { # root
- %v:hlsl.byte_address_buffer<read_write> = var
+ %v:hlsl.byte_address_buffer<read_write> = var @binding_point(0, 0)
}
%foo = @fragment func():void {