[hlsl] Support let pointer to var.
This Cl adds support to the DecomposeMemoryAccess transform to handle a
let which holds a var directly. The `var` is sunk down into any usages
of the `let` and the `let` removed. As similar thing happens to `let`
which directly holds an `access` to one of the `var`s. The `access` is
sunk down through the `let`.
Bug: 349867642
Change-Id: Ia6a090fc01e1b45510751f0d137f634bbd880f2d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196694
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 8c9b75c..3bc2784 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -664,5 +664,42 @@
)");
}
+TEST_F(HlslWriterTest, AccessChainFromLetAccessChain) {
+ auto* Inner =
+ ty.Struct(mod.symbols.New("Inner"),
+ {
+ {mod.symbols.New("c"), ty.f32(), core::type::StructMemberAttributes{}},
+ });
+ auto* sb = ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), Inner, core::type::StructMemberAttributes{}},
+ });
+
+ auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+ var->SetBindingPoint(0, 0);
+ b.ir.root_block->Append(var);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* x = b.Let("x", var);
+ auto* y = b.Let(
+ "y", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), x->Result(0), 1_u));
+ auto* z = b.Let(
+ "z", b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite), y->Result(0), 0_u));
+ b.Let("a", b.Load(z));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+RWByteAddressBuffer v : register(u0);
+void foo() {
+ float a = asfloat(v.Load(4u));
+}
+
+)");
+}
+
} // namespace
} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
index c71c963..f7d2e86 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
@@ -92,11 +92,14 @@
[&](core::ir::StoreVectorElement* sve) { usage_worklist.Push(sve); },
[&](core::ir::Store* st) { usage_worklist.Push(st); },
[&](core::ir::Load* ld) { usage_worklist.Push(ld); },
- [&](core::ir::Access* a) { usage_worklist.Push(a); });
+ [&](core::ir::Access* a) { usage_worklist.Push(a); },
+ [&](core::ir::Let* l) { usage_worklist.Push(l); }, //
+ TINT_ICE_ON_NO_MATCH);
}
auto* var_ty = result->Type()->As<core::type::Pointer>();
- for (auto* inst : usage_worklist) {
+ while (!usage_worklist.IsEmpty()) {
+ auto* inst = usage_worklist.Pop();
// Load instructions can be destroyed by the replacing access function
if (!inst->Alive()) {
continue;
@@ -106,9 +109,19 @@
inst,
[&](core::ir::LoadVectorElement* l) { LoadVectorElement(l, var, var_ty); },
[&](core::ir::StoreVectorElement* s) { StoreVectorElement(s, var, var_ty); },
- [&](core::ir::Store* s) { Store(s); }, //
- [&](core::ir::Load* l) { Load(l, var); }, //
- [&](core::ir::Access* a) { Access(a, var, a->Object()->Type(), 0u); }, //
+ [&](core::ir::Store* s) { Store(s); }, //
+ [&](core::ir::Load* l) { Load(l, var); }, //
+ [&](core::ir::Access* a) { Access(a, var, a->Object()->Type(), 0u); },
+ [&](core::ir::Let* let) {
+ // The `let` is, essentially, an alias for the `var` as it's assigned
+ // directly. Gather all the `let` usages into our worklist, and then replace
+ // the `let` with the `var` itself.
+ for (auto& usage : let->Result(0)->Usages()) {
+ usage_worklist.Push(usage->instruction);
+ }
+ let->Result(0)->ReplaceAllUsesWith(result);
+ let->Destroy();
+ },
TINT_ICE_ON_NO_MATCH);
}
@@ -309,6 +322,15 @@
});
}
+ void InsertLoad(core::ir::Var* var, core::ir::Instruction* inst, uint32_t offset) {
+ b.InsertBefore(inst, [&] {
+ auto* call =
+ MakeLoad(inst, var, inst->Result(0)->Type()->UnwrapPtr(), b.Value(u32(offset)));
+ inst->Result(0)->ReplaceAllUsesWith(call->Result(0));
+ });
+ inst->Destroy();
+ }
+
void Access(core::ir::Access* a,
core::ir::Var* var,
const core::type::Type* obj,
@@ -349,21 +371,20 @@
TINT_ICE_ON_NO_MATCH);
}
- auto insert_load = [&](core::ir::Instruction* inst, uint32_t offset) {
- b.InsertBefore(inst, [&] {
- auto* call = MakeLoad(inst, var, inst->Result(0)->Type(), b.Value(u32(offset)));
- inst->Result(0)->ReplaceAllUsesWith(call->Result(0));
- });
- inst->Destroy();
- };
-
// Copy the usages into a vector so we can remove items from the hashset.
auto usages = a->Result(0)->Usages().Vector();
- for (auto& usage : usages) {
+ while (!usages.IsEmpty()) {
+ auto usage = usages.Pop();
tint::Switch(
- usage.instruction, //
- [&](core::ir::Let*) {
- // TODO(dsinclair): handle let
+ usage.instruction,
+ [&](core::ir::Let* let) {
+ // The `let` is essentially an alias to the `access`. So, add the `let` usages
+ // into the usage worklist, and replace the let with the access chain directly.
+ for (auto& u : let->Result(0)->Usages()) {
+ usages.Push(u);
+ }
+ let->Result(0)->ReplaceAllUsesWith(a->Result(0));
+ let->Destroy();
},
[&](core::ir::Access* sub_access) {
// Treat an access chain of the access chain as a continuation of the outer
@@ -376,11 +397,11 @@
a->Result(0)->RemoveUsage(usage);
byte_offset += CalculateVectorIndex(lve->Index(), obj);
- insert_load(lve, byte_offset);
- }, //
+ InsertLoad(var, lve, byte_offset);
+ },
[&](core::ir::Load* ld) {
a->Result(0)->RemoveUsage(usage);
- insert_load(ld, byte_offset);
+ InsertLoad(var, ld, byte_offset);
},
[&](core::ir::StoreVectorElement*) {
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
index db557f4..9d4a22b 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
@@ -260,7 +260,7 @@
EXPECT_EQ(expect, str());
}
-TEST_F(HlslWriterDecomposeMemoryAccessTest, DISABLED_AccessChainFromLetAccessChain) {
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessChainFromLetAccessChain) {
auto* Inner =
ty.Struct(mod.symbols.New("Inner"),
{
@@ -278,11 +278,12 @@
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
- auto* x = b.Let("x", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), var));
+ auto* x = b.Let("x", var);
auto* y = b.Let(
"y", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), x->Result(0), 1_u));
- b.Let("z", b.Load(b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite),
- y->Result(0), 0_u)));
+ auto* z = b.Let(
+ "z", b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite), y->Result(0), 0_u));
+ b.Let("a", b.Load(z));
b.Return(func);
});
@@ -302,11 +303,13 @@
%foo = @fragment func():void {
$B2: {
- %3:ptr<storage, Inner, read_write> = access %v, 1u
- %a:ptr<storage, Inner, read_write> = let %3
- %5:ptr<storage, f32, read_write> = access %a, 0u
- %6:f32 = load %5
- %b:f32 = let %6
+ %x:ptr<storage, SB, read_write> = let %v
+ %4:ptr<storage, Inner, read_write> = access %x, 1u
+ %y:ptr<storage, Inner, read_write> = let %4
+ %6:ptr<storage, f32, read_write> = access %y, 0u
+ %z:ptr<storage, f32, read_write> = let %6
+ %8:f32 = load %z
+ %a:f32 = let %8
ret
}
}
@@ -314,9 +317,13 @@
ASSERT_EQ(src, str());
auto* expect = R"(
-SB = struct @align(16) {
+Inner = struct @align(4) {
+ c:f32 @offset(0)
+}
+
+SB = struct @align(4) {
a:i32 @offset(0)
- b:vec3<f32> @offset(16)
+ b:Inner @offset(4)
}
$B1: { # root
@@ -325,9 +332,9 @@
%foo = @fragment func():void {
$B2: {
- %3:vec3<u32> = %v.Load3 16u
- %a:vec3<f32> = bitcast %3
- %b:f32 = %a 1u
+ %3:u32 = %v.Load 4u
+ %4:f32 = bitcast %3
+ %a:f32 = let %4
ret
}
}