[hlsl] Add load support to decompose memory access.

This CL starts adding support for `load` instructions in the
decompose memory access transform for the HLSL backend.
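
For example (illustrative, matching the updated tests), a direct load of
a `vec4<f32>` storage vector

    %3:vec4<f32> = load %v

is decomposed into a ByteAddressBuffer load plus a bitcast

    %3:vec4<u32> = %v.Load4 0u
    %4:vec4<f32> = bitcast %3

which the printer emits as `asfloat(v.Load4(0u))`. The `f16` variants use
templated loads (e.g. `v.Load4<vector<float16_t, 4>>(0u)`) and need no
bitcast.
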
Bug: 349867642
Change-Id: If5c3f16aea6afbc8a677544bec5c9d6230d20a3c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196454
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 646c246..94e0ff2 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -1103,7 +1103,9 @@
void Validator::CheckAccess(const Access* a) {
auto* obj_view = a->Object()->Type()->As<core::type::MemoryView>();
auto* ty = obj_view ? obj_view->StoreType() : a->Object()->Type();
- enum Kind { kPtr, kRef, kValue };
+
+ enum Kind : uint8_t { kPtr, kRef, kValue };
+
auto kind_of = [&](const core::type::Type* type) {
return tint::Switch(
type, //
@@ -1111,6 +1113,7 @@
[&](const core::type::Reference*) { return kRef; }, //
[&](Default) { return kValue; });
};
+
const Kind in_kind = kind_of(a->Object()->Type());
auto desc_of = [&](Kind kind, const core::type::Type* type) {
switch (kind) {
diff --git a/src/tint/lang/hlsl/builtin_fn.cc b/src/tint/lang/hlsl/builtin_fn.cc
index 81817bd..091e8d2 100644
--- a/src/tint/lang/hlsl/builtin_fn.cc
+++ b/src/tint/lang/hlsl/builtin_fn.cc
@@ -60,6 +60,14 @@
return "Load3";
case BuiltinFn::kLoad4:
return "Load4";
+ case BuiltinFn::kLoadF16:
+ return "LoadF16";
+ case BuiltinFn::kLoad2F16:
+ return "Load2F16";
+ case BuiltinFn::kLoad3F16:
+ return "Load3F16";
+ case BuiltinFn::kLoad4F16:
+ return "Load4F16";
case BuiltinFn::kStore:
return "Store";
case BuiltinFn::kStore2:
diff --git a/src/tint/lang/hlsl/builtin_fn.h b/src/tint/lang/hlsl/builtin_fn.h
index 7d40b62..f9f0d3f 100644
--- a/src/tint/lang/hlsl/builtin_fn.h
+++ b/src/tint/lang/hlsl/builtin_fn.h
@@ -56,6 +56,10 @@
kLoad2,
kLoad3,
kLoad4,
+ kLoadF16,
+ kLoad2F16,
+ kLoad3F16,
+ kLoad4F16,
kStore,
kStore2,
kStore3,
diff --git a/src/tint/lang/hlsl/hlsl.def b/src/tint/lang/hlsl/hlsl.def
index 49b7e6a..6adb7ca 100644
--- a/src/tint/lang/hlsl/hlsl.def
+++ b/src/tint/lang/hlsl/hlsl.def
@@ -52,7 +52,7 @@
type vec4<T>
@display("vec{N}<{T}>") type vec<N: num, T>
-type byte_address_buffer<T, A: access>
+type byte_address_buffer<A: access>
////////////////////////////////////////////////////////////////////////////////
// Type matchers //
@@ -89,12 +89,17 @@
fn f16tof32(u32) -> f32
fn f16tof32[N: num](vec<N, u32>) -> vec<N, f32>
-@member_function fn Load[T](byte_address_buffer<T, readable>, offset: u32) -> u32
-@member_function fn Load2[T](byte_address_buffer<T, readable>, offset: u32) -> vec2<u32>
-@member_function fn Load3[T](byte_address_buffer<T, readable>, offset: u32) -> vec3<u32>
-@member_function fn Load4[T](byte_address_buffer<T, readable>, offset: u32) -> vec4<u32>
+@member_function fn Load(byte_address_buffer<readable>, offset: u32) -> u32
+@member_function fn Load2(byte_address_buffer<readable>, offset: u32) -> vec2<u32>
+@member_function fn Load3(byte_address_buffer<readable>, offset: u32) -> vec3<u32>
+@member_function fn Load4(byte_address_buffer<readable>, offset: u32) -> vec4<u32>
-@member_function fn Store[T](byte_address_buffer<T, writable>, offset: u32, value: u32)
-@member_function fn Store2[T](byte_address_buffer<T, writable>, offset: u32, value: vec2<u32>)
-@member_function fn Store3[T](byte_address_buffer<T, writable>, offset: u32, value: vec3<u32>)
-@member_function fn Store4[T](byte_address_buffer<T, writable>, offset: u32, value: vec4<u32>)
+@member_function fn LoadF16(byte_address_buffer<readable>, offset: u32) -> f16
+@member_function fn Load2F16(byte_address_buffer<readable>, offset: u32) -> vec2<f16>
+@member_function fn Load3F16(byte_address_buffer<readable>, offset: u32) -> vec3<f16>
+@member_function fn Load4F16(byte_address_buffer<readable>, offset: u32) -> vec4<f16>
+
+@member_function fn Store(byte_address_buffer<writable>, offset: u32, value: u32)
+@member_function fn Store2(byte_address_buffer<writable>, offset: u32, value: vec2<u32>)
+@member_function fn Store3(byte_address_buffer<writable>, offset: u32, value: vec3<u32>)
+@member_function fn Store4(byte_address_buffer<writable>, offset: u32, value: vec4<u32>)
diff --git a/src/tint/lang/hlsl/intrinsic/data.cc b/src/tint/lang/hlsl/intrinsic/data.cc
index c0750db..5780365 100644
--- a/src/tint/lang/hlsl/intrinsic/data.cc
+++ b/src/tint/lang/hlsl/intrinsic/data.cc
@@ -252,25 +252,19 @@
/// TypeMatcher for 'type byte_address_buffer'
constexpr TypeMatcher kByteAddressBufferMatcher {
/* match */ [](MatchState& state, const Type* ty) -> const Type* {
- const Type* T = nullptr;
Number A = Number::invalid;
- if (!MatchByteAddressBuffer(state, ty, T, A)) {
- return nullptr;
- }
- T = state.Type(T);
- if (T == nullptr) {
+ if (!MatchByteAddressBuffer(state, ty, A)) {
return nullptr;
}
A = state.Num(A);
if (!A.IsValid()) {
return nullptr;
}
- return BuildByteAddressBuffer(state, ty, T, A);
+ return BuildByteAddressBuffer(state, ty, A);
},
-/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
- state->PrintType(T);StyledText A;
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText A;
state->PrintNum(A);
- out << style::Type("byte_address_buffer", "<", T, ", ", A, ">");
+ out << style::Type("byte_address_buffer", "<", A, ">");
}
};
@@ -421,20 +415,24 @@
/* [16] */ MatcherIndex(0),
/* [17] */ MatcherIndex(4),
/* [18] */ MatcherIndex(11),
- /* [19] */ MatcherIndex(0),
- /* [20] */ MatcherIndex(3),
- /* [21] */ MatcherIndex(11),
- /* [22] */ MatcherIndex(0),
- /* [23] */ MatcherIndex(4),
- /* [24] */ MatcherIndex(7),
+ /* [19] */ MatcherIndex(3),
+ /* [20] */ MatcherIndex(7),
+ /* [21] */ MatcherIndex(3),
+ /* [22] */ MatcherIndex(8),
+ /* [23] */ MatcherIndex(3),
+ /* [24] */ MatcherIndex(9),
/* [25] */ MatcherIndex(3),
- /* [26] */ MatcherIndex(8),
- /* [27] */ MatcherIndex(3),
- /* [28] */ MatcherIndex(9),
- /* [29] */ MatcherIndex(3),
- /* [30] */ MatcherIndex(13),
- /* [31] */ MatcherIndex(14),
- /* [32] */ MatcherIndex(12),
+ /* [26] */ MatcherIndex(7),
+ /* [27] */ MatcherIndex(5),
+ /* [28] */ MatcherIndex(8),
+ /* [29] */ MatcherIndex(5),
+ /* [30] */ MatcherIndex(9),
+ /* [31] */ MatcherIndex(5),
+ /* [32] */ MatcherIndex(11),
+ /* [33] */ MatcherIndex(4),
+ /* [34] */ MatcherIndex(13),
+ /* [35] */ MatcherIndex(14),
+ /* [36] */ MatcherIndex(12),
};
static_assert(MatcherIndicesIndex::CanIndex(kMatcherIndices),
@@ -444,7 +442,7 @@
{
/* [0] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(21),
+ /* matcher_indices */ MatcherIndicesIndex(32),
},
{
/* [1] */
@@ -459,7 +457,7 @@
{
/* [3] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(21),
+ /* matcher_indices */ MatcherIndicesIndex(32),
},
{
/* [4] */
@@ -469,12 +467,12 @@
{
/* [5] */
/* usage */ core::ParameterUsage::kValue,
- /* matcher_indices */ MatcherIndicesIndex(24),
+ /* matcher_indices */ MatcherIndicesIndex(20),
},
{
/* [6] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(21),
+ /* matcher_indices */ MatcherIndicesIndex(32),
},
{
/* [7] */
@@ -484,12 +482,12 @@
{
/* [8] */
/* usage */ core::ParameterUsage::kValue,
- /* matcher_indices */ MatcherIndicesIndex(26),
+ /* matcher_indices */ MatcherIndicesIndex(22),
},
{
/* [9] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(21),
+ /* matcher_indices */ MatcherIndicesIndex(32),
},
{
/* [10] */
@@ -499,7 +497,7 @@
{
/* [11] */
/* usage */ core::ParameterUsage::kValue,
- /* matcher_indices */ MatcherIndicesIndex(28),
+ /* matcher_indices */ MatcherIndicesIndex(24),
},
{
/* [12] */
@@ -550,7 +548,7 @@
{
/* [0] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(30),
+ /* matcher_indices */ MatcherIndicesIndex(34),
/* kind */ TemplateInfo::Kind::kType,
},
{
@@ -562,7 +560,7 @@
{
/* [2] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(31),
+ /* matcher_indices */ MatcherIndicesIndex(35),
/* kind */ TemplateInfo::Kind::kType,
},
{
@@ -574,7 +572,7 @@
{
/* [4] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(32),
+ /* matcher_indices */ MatcherIndicesIndex(36),
/* kind */ TemplateInfo::Kind::kType,
},
{
@@ -583,12 +581,6 @@
/* matcher_indices */ MatcherIndicesIndex(/* invalid */),
/* kind */ TemplateInfo::Kind::kNumber,
},
- {
- /* [6] */
- /* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(/* invalid */),
- /* kind */ TemplateInfo::Kind::kType,
- },
};
static_assert(TemplateIndex::CanIndex(kTemplates),
@@ -710,8 +702,8 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
/* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(6),
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
/* parameters */ ParameterIndex(12),
/* return_matcher_indices */ MatcherIndicesIndex(8),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
@@ -721,10 +713,10 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
/* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(6),
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
/* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(24),
+ /* return_matcher_indices */ MatcherIndicesIndex(20),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -732,10 +724,10 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
/* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(6),
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
/* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(26),
+ /* return_matcher_indices */ MatcherIndicesIndex(22),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -743,52 +735,96 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
/* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(6),
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
/* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(28),
+ /* return_matcher_indices */ MatcherIndicesIndex(24),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [14] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
- /* num_parameters */ 3,
+ /* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(6),
- /* parameters */ ParameterIndex(0),
- /* return_matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(12),
+ /* return_matcher_indices */ MatcherIndicesIndex(27),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [15] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
- /* num_parameters */ 3,
+ /* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(6),
- /* parameters */ ParameterIndex(3),
- /* return_matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(12),
+ /* return_matcher_indices */ MatcherIndicesIndex(26),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [16] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
- /* num_parameters */ 3,
+ /* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(6),
- /* parameters */ ParameterIndex(6),
- /* return_matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(12),
+ /* return_matcher_indices */ MatcherIndicesIndex(28),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [17] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
+ /* num_parameters */ 2,
+ /* num_explicit_templates */ 0,
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(12),
+ /* return_matcher_indices */ MatcherIndicesIndex(30),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
+ {
+ /* [18] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
/* num_parameters */ 3,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(6),
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(0),
+ /* return_matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
+ {
+ /* [19] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
+ /* num_parameters */ 3,
+ /* num_explicit_templates */ 0,
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(3),
+ /* return_matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
+ {
+ /* [20] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
+ /* num_parameters */ 3,
+ /* num_explicit_templates */ 0,
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(6),
+ /* return_matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
+ {
+ /* [21] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
+ /* num_parameters */ 3,
+ /* num_explicit_templates */ 0,
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
/* parameters */ ParameterIndex(9),
/* return_matcher_indices */ MatcherIndicesIndex(/* invalid */),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
@@ -836,52 +872,76 @@
},
{
/* [5] */
- /* fn Load[T](byte_address_buffer<T, readable>, offset: u32) -> u32 */
+ /* fn Load(byte_address_buffer<readable>, offset: u32) -> u32 */
/* num overloads */ 1,
/* overloads */ OverloadIndex(10),
},
{
/* [6] */
- /* fn Load2[T](byte_address_buffer<T, readable>, offset: u32) -> vec2<u32> */
+ /* fn Load2(byte_address_buffer<readable>, offset: u32) -> vec2<u32> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(11),
},
{
/* [7] */
- /* fn Load3[T](byte_address_buffer<T, readable>, offset: u32) -> vec3<u32> */
+ /* fn Load3(byte_address_buffer<readable>, offset: u32) -> vec3<u32> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(12),
},
{
/* [8] */
- /* fn Load4[T](byte_address_buffer<T, readable>, offset: u32) -> vec4<u32> */
+ /* fn Load4(byte_address_buffer<readable>, offset: u32) -> vec4<u32> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(13),
},
{
/* [9] */
- /* fn Store[T](byte_address_buffer<T, writable>, offset: u32, value: u32) */
+ /* fn LoadF16(byte_address_buffer<readable>, offset: u32) -> f16 */
/* num overloads */ 1,
/* overloads */ OverloadIndex(14),
},
{
/* [10] */
- /* fn Store2[T](byte_address_buffer<T, writable>, offset: u32, value: vec2<u32>) */
+ /* fn Load2F16(byte_address_buffer<readable>, offset: u32) -> vec2<f16> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(15),
},
{
/* [11] */
- /* fn Store3[T](byte_address_buffer<T, writable>, offset: u32, value: vec3<u32>) */
+ /* fn Load3F16(byte_address_buffer<readable>, offset: u32) -> vec3<f16> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(16),
},
{
/* [12] */
- /* fn Store4[T](byte_address_buffer<T, writable>, offset: u32, value: vec4<u32>) */
+ /* fn Load4F16(byte_address_buffer<readable>, offset: u32) -> vec4<f16> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(17),
},
+ {
+ /* [13] */
+ /* fn Store(byte_address_buffer<writable>, offset: u32, value: u32) */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(18),
+ },
+ {
+ /* [14] */
+ /* fn Store2(byte_address_buffer<writable>, offset: u32, value: vec2<u32>) */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(19),
+ },
+ {
+ /* [15] */
+ /* fn Store3(byte_address_buffer<writable>, offset: u32, value: vec3<u32>) */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(20),
+ },
+ {
+ /* [16] */
+ /* fn Store4(byte_address_buffer<writable>, offset: u32, value: vec4<u32>) */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(21),
+ },
};
// clang-format on
diff --git a/src/tint/lang/hlsl/intrinsic/type_matchers.h b/src/tint/lang/hlsl/intrinsic/type_matchers.h
index 7567296..e99c48b 100644
--- a/src/tint/lang/hlsl/intrinsic/type_matchers.h
+++ b/src/tint/lang/hlsl/intrinsic/type_matchers.h
@@ -36,10 +36,8 @@
inline bool MatchByteAddressBuffer(core::intrinsic::MatchState&,
const core::type::Type* ty,
- const core::type::Type*& T,
core::intrinsic::Number& A) {
if (auto* buf = ty->As<type::ByteAddressBuffer>()) {
- T = buf->StoreType();
A = core::intrinsic::Number(static_cast<uint32_t>(buf->Access()));
return true;
}
@@ -48,9 +46,8 @@
inline const type::ByteAddressBuffer* BuildByteAddressBuffer(core::intrinsic::MatchState& state,
const core::type::Type*,
- const core::type::Type* T,
core::intrinsic::Number& A) {
- return state.types.Get<type::ByteAddressBuffer>(T, static_cast<core::Access>(A.Value()));
+ return state.types.Get<type::ByteAddressBuffer>(static_cast<core::Access>(A.Value()));
}
} // namespace tint::hlsl::intrinsic
diff --git a/src/tint/lang/hlsl/ir/member_builtin_call_test.cc b/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
index 86ff37d..e886929 100644
--- a/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
+++ b/src/tint/lang/hlsl/ir/member_builtin_call_test.cc
@@ -50,7 +50,7 @@
using IR_HlslMemberBuiltinCallTest = core::ir::IRTestHelper;
TEST_F(IR_HlslMemberBuiltinCallTest, Clone) {
- auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(ty.vec3<i32>(), core::Access::kReadWrite);
+ auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kReadWrite);
auto* t = b.FunctionParam("t", buf);
auto* builtin = b.MemberCall<MemberBuiltinCall>(mod.Types().u32(), BuiltinFn::kLoad, t, 2_u);
@@ -71,7 +71,7 @@
}
TEST_F(IR_HlslMemberBuiltinCallTest, DoesNotMatchNonMemberFunction) {
- auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(ty.vec3<i32>(), core::Access::kRead);
+ auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
auto* t = b.FunctionParam("t", buf);
@@ -87,7 +87,7 @@
ASSERT_NE(res, Success);
EXPECT_EQ(
res.Failure().reason.Str(),
- R"(:3:17 error: asint: no matching call to 'asint(hlsl.byte_address_buffer<vec3<i32>, read>, u32)'
+ R"(:3:17 error: asint: no matching call to 'asint(hlsl.byte_address_buffer<read>, u32)'
%3:u32 = %t.asint 2u
^^^^^
@@ -97,7 +97,7 @@
^^^
note: # Disassembly
-%foo = func(%t:hlsl.byte_address_buffer<vec3<i32>, read>):u32 {
+%foo = func(%t:hlsl.byte_address_buffer<read>):u32 {
$B1: {
%3:u32 = %t.asint 2u
ret %3
@@ -107,7 +107,7 @@
}
TEST_F(IR_HlslMemberBuiltinCallTest, DoesNotMatchIncorrectType) {
- auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(ty.vec3<i32>(), core::Access::kRead);
+ auto* buf = ty.Get<hlsl::type::ByteAddressBuffer>(core::Access::kRead);
auto* t = b.FunctionParam("t", buf);
@@ -123,10 +123,10 @@
ASSERT_NE(res, Success);
EXPECT_EQ(
res.Failure().reason.Str(),
- R"(:3:17 error: Store: no matching call to 'Store(hlsl.byte_address_buffer<vec3<i32>, read>, u32, u32)'
+ R"(:3:17 error: Store: no matching call to 'Store(hlsl.byte_address_buffer<read>, u32, u32)'
1 candidate function:
- • 'Store(byte_address_buffer<T, write' or 'read_write> ✗ , offset: u32 ✓ , value: u32 ✓ )'
+ • 'Store(byte_address_buffer<write' or 'read_write> ✗ , offset: u32 ✓ , value: u32 ✓ )'
%3:u32 = %t.Store 2u, 2u
^^^^^
@@ -136,7 +136,7 @@
^^^
note: # Disassembly
-%foo = func(%t:hlsl.byte_address_buffer<vec3<i32>, read>):u32 {
+%foo = func(%t:hlsl.byte_address_buffer<read>):u32 {
$B1: {
%3:u32 = %t.Store 2u, 2u
ret %3
diff --git a/src/tint/lang/hlsl/type/byte_address_buffer.cc b/src/tint/lang/hlsl/type/byte_address_buffer.cc
index 33a271b..289316c 100644
--- a/src/tint/lang/hlsl/type/byte_address_buffer.cc
+++ b/src/tint/lang/hlsl/type/byte_address_buffer.cc
@@ -40,28 +40,27 @@
namespace tint::hlsl::type {
-ByteAddressBuffer::ByteAddressBuffer(const core::type::Type* source_type, core::Access access)
+ByteAddressBuffer::ByteAddressBuffer(core::Access access)
: Base(static_cast<size_t>(Hash(tint::TypeCode::Of<ByteAddressBuffer>().bits)),
core::AddressSpace::kStorage,
- source_type,
+ nullptr,
access) {}
bool ByteAddressBuffer::Equals(const UniqueNode& other) const {
if (auto* o = other.As<ByteAddressBuffer>()) {
- return o->StoreType() == StoreType() && o->Access() == Access();
+ return o->Access() == Access();
}
return false;
}
std::string ByteAddressBuffer::FriendlyName() const {
StringStream out;
- out << "hlsl.byte_address_buffer<" << StoreType()->FriendlyName() << ", " << Access() << ">";
+ out << "hlsl.byte_address_buffer<" << Access() << ">";
return out.str();
}
ByteAddressBuffer* ByteAddressBuffer::Clone(core::type::CloneContext& ctx) const {
- auto* ty = StoreType()->Clone(ctx);
- return ctx.dst.mgr->Get<ByteAddressBuffer>(ty, Access());
+ return ctx.dst.mgr->Get<ByteAddressBuffer>(Access());
}
} // namespace tint::hlsl::type
diff --git a/src/tint/lang/hlsl/type/byte_address_buffer.h b/src/tint/lang/hlsl/type/byte_address_buffer.h
index 990a056..a158581 100644
--- a/src/tint/lang/hlsl/type/byte_address_buffer.h
+++ b/src/tint/lang/hlsl/type/byte_address_buffer.h
@@ -39,9 +39,8 @@
class ByteAddressBuffer final : public Castable<ByteAddressBuffer, core::type::MemoryView> {
public:
/// Constructor
- /// @param store_type the source buffer type
/// @param access the buffer access mode
- explicit ByteAddressBuffer(const core::type::Type* store_type, core::Access access);
+ explicit ByteAddressBuffer(core::Access access);
/// @param other the other node to compare against
/// @returns true if the this type is equal to @p other
@@ -53,6 +52,8 @@
/// @param ctx the clone context
/// @returns a clone of this type
ByteAddressBuffer* Clone(core::type::CloneContext& ctx) const override;
+
+ const Type* StoreType() const = delete;
};
} // namespace tint::hlsl::type
diff --git a/src/tint/lang/hlsl/type/byte_address_buffer_test.cc b/src/tint/lang/hlsl/type/byte_address_buffer_test.cc
index b4d7ce4..c39f4c2 100644
--- a/src/tint/lang/hlsl/type/byte_address_buffer_test.cc
+++ b/src/tint/lang/hlsl/type/byte_address_buffer_test.cc
@@ -36,25 +36,18 @@
namespace {
TEST(HlslTypeByteAddressBuffer, Equals) {
- core::type::F32 f{};
- core::type::I32 i{};
-
- const ByteAddressBuffer a(&f, core::Access::kRead);
- const ByteAddressBuffer b(&f, core::Access::kRead);
- const ByteAddressBuffer c(&f, core::Access::kReadWrite);
- const ByteAddressBuffer d(&i, core::Access::kRead);
+ const ByteAddressBuffer a(core::Access::kRead);
+ const ByteAddressBuffer b(core::Access::kRead);
+ const ByteAddressBuffer c(core::Access::kReadWrite);
EXPECT_TRUE(a.Equals(b));
EXPECT_FALSE(a.Equals(c));
- EXPECT_FALSE(a.Equals(d));
- EXPECT_FALSE(a.Equals(i));
}
TEST(HlslTypeByteAddressBuffer, FriendlyName) {
- core::type::F32 f{};
- const ByteAddressBuffer l(&f, core::Access::kReadWrite);
+ const ByteAddressBuffer l(core::Access::kReadWrite);
- EXPECT_EQ(l.FriendlyName(), "hlsl.byte_address_buffer<f32, read_write>");
+ EXPECT_EQ(l.FriendlyName(), "hlsl.byte_address_buffer<read_write>");
}
} // namespace
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 2b1d895..d878f14 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -269,17 +269,18 @@
)");
}
-TEST_F(HlslWriterTest, AccessVectorLoad) {
+TEST_F(HlslWriterTest, AccessStorageVector) {
auto* var = b.Var<storage, vec4<f32>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
b.ir.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
- b.Let("a", b.LoadVectorElement(var, 0_u));
- b.Let("b", b.LoadVectorElement(var, 1_u));
- b.Let("c", b.LoadVectorElement(var, 2_u));
- b.Let("d", b.LoadVectorElement(var, 3_u));
+ b.Let("a", b.Load(var));
+ b.Let("b", b.LoadVectorElement(var, 0_u));
+ b.Let("c", b.LoadVectorElement(var, 1_u));
+ b.Let("d", b.LoadVectorElement(var, 2_u));
+ b.Let("e", b.LoadVectorElement(var, 3_u));
b.Return(func);
});
@@ -287,16 +288,272 @@
EXPECT_EQ(output_.hlsl, R"(
ByteAddressBuffer v : register(t0);
void foo() {
- float a = asfloat(v.Load(0u));
- float b = asfloat(v.Load(4u));
- float c = asfloat(v.Load(8u));
- float d = asfloat(v.Load(12u));
+ float4 a = asfloat(v.Load4(0u));
+ float b = asfloat(v.Load(0u));
+ float c = asfloat(v.Load(4u));
+ float d = asfloat(v.Load(8u));
+ float e = asfloat(v.Load(12u));
}
)");
}
-TEST_F(HlslWriterTest, AccessVectorStore) {
+TEST_F(HlslWriterTest, AccessStorageVectorF16) {
+ auto* var = b.Var<storage, vec4<f16>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.LoadVectorElement(var, 0_u));
+ b.Let("c", b.LoadVectorElement(var, 1_u));
+ b.Let("d", b.LoadVectorElement(var, 2_u));
+ b.Let("e", b.LoadVectorElement(var, 3_u));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+ByteAddressBuffer v : register(t0);
+void foo() {
+ vector<float16_t, 4> a = v.Load4<vector<float16_t, 4>>(0u);
+ float16_t b = v.Load<float16_t>(0u);
+ float16_t c = v.Load<float16_t>(2u);
+ float16_t d = v.Load<float16_t>(4u);
+ float16_t e = v.Load<float16_t>(6u);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, AccessStorageMatrix) {
+ auto* var = b.Var<storage, mat4x4<f32>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<storage, vec4<f32>, core::Access::kRead>(), var, 3_u)));
+ b.Let("c", b.LoadVectorElement(
+ b.Access(ty.ptr<storage, vec4<f32>, core::Access::kRead>(), var, 1_u), 2_u));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+ByteAddressBuffer v : register(t0);
+float4x4 v_1(uint offset) {
+ float4 v_2 = asfloat(v.Load4((offset + 0u)));
+ float4 v_3 = asfloat(v.Load4((offset + 16u)));
+ float4 v_4 = asfloat(v.Load4((offset + 32u)));
+ return float4x4(v_2, v_3, v_4, asfloat(v.Load4((offset + 48u))));
+}
+
+void foo() {
+ float4x4 a = v_1(0u);
+ float4 b = asfloat(v.Load4(48u));
+ float c = asfloat(v.Load(24u));
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, AccessStorageArray) {
+ auto* var = b.Var<storage, array<vec3<f32>, 5>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<storage, vec3<f32>, core::Access::kRead>(), var, 3_u)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+ByteAddressBuffer v : register(t0);
+typedef float3 ary_ret[5];
+ary_ret v_1(uint offset) {
+ float3 a[5] = (float3[5])0;
+ {
+ uint v_2 = 0u;
+ v_2 = 0u;
+ while(true) {
+ uint v_3 = v_2;
+ if ((v_3 >= 5u)) {
+ break;
+ }
+ a[v_3] = asfloat(v.Load3((offset + (v_3 * 16u))));
+ {
+ v_2 = (v_3 + 1u);
+ }
+ continue;
+ }
+ }
+ float3 v_4[5] = a;
+ return v_4;
+}
+
+void foo() {
+ float3 a[5] = v_1(0u);
+ float3 b = asfloat(v.Load3(48u));
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, AccessStorageStruct) {
+ auto* SB = ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), ty.f32(), core::type::StructMemberAttributes{}},
+ });
+
+ auto* var = b.Var("v", storage, SB, core::Access::kRead);
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<storage, f32, core::Access::kRead>(), var, 1_u)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct SB {
+ int a;
+ float b;
+};
+
+
+ByteAddressBuffer v : register(t0);
+SB v_1(uint offset) {
+ int v_2 = asint(v.Load((offset + 0u)));
+ SB v_3 = {v_2, asfloat(v.Load((offset + 4u)))};
+ return v_3;
+}
+
+void foo() {
+ SB a = v_1(0u);
+ float b = asfloat(v.Load(4u));
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, AccessStorageNested) {
+ auto* Inner = ty.Struct(
+ mod.symbols.New("Inner"),
+ {
+ {mod.symbols.New("s"), ty.mat3x3<f32>(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("t"), ty.array<vec3<f32>, 5>(), core::type::StructMemberAttributes{}},
+ });
+ auto* Outer =
+ ty.Struct(mod.symbols.New("Outer"),
+ {
+ {mod.symbols.New("x"), ty.f32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("y"), Inner, core::type::StructMemberAttributes{}},
+ });
+
+ auto* SB = ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), Outer, core::type::StructMemberAttributes{}},
+ });
+
+ auto* var = b.Var("v", storage, SB, core::Access::kRead);
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.LoadVectorElement(b.Access(ty.ptr<storage, vec3<f32>, core::Access::kRead>(),
+ var, 1_u, 1_u, 1_u, 3_u),
+ 2_u));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct Inner {
+ float3x3 s;
+ float3 t[5];
+};
+
+struct Outer {
+ float x;
+ Inner y;
+};
+
+struct SB {
+ int a;
+ Outer b;
+};
+
+
+ByteAddressBuffer v : register(t0);
+typedef float3 ary_ret[5];
+ary_ret v_1(uint offset) {
+ float3 a[5] = (float3[5])0;
+ {
+ uint v_2 = 0u;
+ v_2 = 0u;
+ while(true) {
+ uint v_3 = v_2;
+ if ((v_3 >= 5u)) {
+ break;
+ }
+ a[v_3] = asfloat(v.Load3((offset + (v_3 * 16u))));
+ {
+ v_2 = (v_3 + 1u);
+ }
+ continue;
+ }
+ }
+ float3 v_4[5] = a;
+ return v_4;
+}
+
+float3x3 v_5(uint offset) {
+ float3 v_6 = asfloat(v.Load3((offset + 0u)));
+ float3 v_7 = asfloat(v.Load3((offset + 16u)));
+ return float3x3(v_6, v_7, asfloat(v.Load3((offset + 32u))));
+}
+
+Inner v_8(uint offset) {
+ float3x3 v_9 = v_5((offset + 0u));
+ float3 v_10[5] = v_1((offset + 48u));
+ Inner v_11 = {v_9, v_10};
+ return v_11;
+}
+
+Outer v_12(uint offset) {
+ float v_13 = asfloat(v.Load((offset + 0u)));
+ Inner v_14 = v_8((offset + 16u));
+ Outer v_15 = {v_13, v_14};
+ return v_15;
+}
+
+SB v_16(uint offset) {
+ int v_17 = asint(v.Load((offset + 0u)));
+ Outer v_18 = v_12((offset + 16u));
+ SB v_19 = {v_17, v_18};
+ return v_19;
+}
+
+void foo() {
+ SB a = v_16(0u);
+ float b = asfloat(v.Load(136u));
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, DISABLED_AccessStorageStoreVector) {
auto* var = b.Var<storage, vec4<f32>, core::Access::kReadWrite>("v");
var->SetBindingPoint(0, 0);
diff --git a/src/tint/lang/hlsl/writer/printer/printer.cc b/src/tint/lang/hlsl/writer/printer/printer.cc
index 116a7a1..97b3600 100644
--- a/src/tint/lang/hlsl/writer/printer/printer.cc
+++ b/src/tint/lang/hlsl/writer/printer/printer.cc
@@ -655,8 +655,25 @@
}
void EmitHlslMemberBuiltinCall(StringStream& out, const hlsl::ir::MemberBuiltinCall* c) {
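+        // Map the internal *F16 load builtins onto templated HLSL loads,
+        // e.g. `LoadF16` is emitted as `Load<float16_t>`.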
+ BuiltinFn fn = c->Func();
+ std::string suffix = "";
+ if (fn == BuiltinFn::kLoadF16) {
+ fn = BuiltinFn::kLoad;
+ suffix = "<float16_t>";
+ } else if (fn == BuiltinFn::kLoad2F16) {
+ fn = BuiltinFn::kLoad2;
+ suffix = "<vector<float16_t, 2>>";
+ } else if (fn == BuiltinFn::kLoad3F16) {
+ fn = BuiltinFn::kLoad3;
+ suffix = "<vector<float16_t, 3>>";
+ } else if (fn == BuiltinFn::kLoad4F16) {
+ fn = BuiltinFn::kLoad4;
+ suffix = "<vector<float16_t, 4>>";
+ }
+
EmitValue(out, c->Object());
- out << "." << c->Func() << "(";
+ out << "." << fn << suffix << "(";
+
bool needs_comma = false;
for (const auto* arg : c->Args()) {
if (needs_comma) {
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
index 0a2609b..8476ff3 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
@@ -27,6 +27,8 @@
#include "src/tint/lang/hlsl/writer/raise/decompose_memory_access.h"
+#include <utility>
+
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/hlsl/builtin_fn.h"
@@ -51,6 +53,10 @@
/// The type manager.
core::type::Manager& ty{ir.Types()};
+ using VarTypePair = std::pair<core::ir::Var*, const core::type::Type*>;
+    /// Maps a (var, type) pair to the generated load helper function
+ Hashmap<VarTypePair, core::ir::Function*, 2> var_and_type_to_load_fn_{};
+
/// Process the module.
void Process() {
Vector<core::ir::Var*, 4> var_worklist;
@@ -75,80 +81,383 @@
}
for (auto* var : var_worklist) {
- auto* var_ty = var->Result(0)->Type()->As<core::type::Pointer>();
-
- core::type::Type* buf_type =
- ty.Get<hlsl::type::ByteAddressBuffer>(var_ty->StoreType(), var_ty->Access());
-
- // Swap the result type of the `var` to the new HLSL result type
auto* result = var->Result(0);
- result->SetType(buf_type);
// Find all the usages of the `var` which is loading or storing.
Vector<core::ir::Instruction*, 4> usage_worklist;
for (auto& usage : result->Usages()) {
Switch(
- usage->instruction, //
+ usage->instruction,
[&](core::ir::LoadVectorElement* lve) { usage_worklist.Push(lve); },
- [&](core::ir::StoreVectorElement* sve) { usage_worklist.Push(sve); }, //
-
- [&](core::ir::Store* st) { usage_worklist.Push(st); }, //
- [&](core::ir::Load* ld) { usage_worklist.Push(ld); } //
- );
+ [&](core::ir::StoreVectorElement* sve) { usage_worklist.Push(sve); },
+ [&](core::ir::Store* st) { usage_worklist.Push(st); },
+ [&](core::ir::Load* ld) { usage_worklist.Push(ld); },
+ [&](core::ir::Access* a) { usage_worklist.Push(a); });
}
+ auto* var_ty = result->Type()->As<core::type::Pointer>();
for (auto* inst : usage_worklist) {
+            // Load instructions may already have been destroyed while replacing an access
+ if (!inst->Alive()) {
+ continue;
+ }
+
Switch(
- inst, //
- [&](core::ir::LoadVectorElement* lve) {
- // Converts to:
- //
- // %1:u32 = v.Load 0u
- // %b:f32 = bitcast %1
+ inst,
+ [&](core::ir::LoadVectorElement* l) { LoadVectorElement(l, var, var_ty); },
+ [&](core::ir::StoreVectorElement* s) { StoreVectorElement(s, var, var_ty); },
+ [&](core::ir::Store* s) { Store(s); }, //
+ [&](core::ir::Load* l) { Load(l, var); }, //
+ [&](core::ir::Access* a) { Access(a, var, var_ty); }, //
+ TINT_ICE_ON_NO_MATCH);
+ }
- auto* idx_value = lve->Index()->As<core::ir::Constant>();
- TINT_ASSERT(idx_value);
+ // Swap the result type of the `var` to the new HLSL result type
+ result->SetType(ty.Get<hlsl::type::ByteAddressBuffer>(var_ty->Access()));
+ }
+ }
- uint32_t pos = idx_value->Value()->ValueAs<uint32_t>() *
- var_ty->StoreType()->DeepestElement()->Size();
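+    // Converts a constant vector element index into a byte offset, using the size of the
+    // store type's deepest element.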
+ uint32_t CalculateVectorIndex(core::ir::Value* v, const core::type::Type* store_ty) {
+ auto* idx_value = v->As<core::ir::Constant>();
- auto* builtin = b.MemberCall<hlsl::ir::MemberBuiltinCall>(
- ty.u32(), BuiltinFn::kLoad, var, u32(pos));
+ // TODO(dsinclair): Handle non-constant vector indices.
+ TINT_ASSERT(idx_value);
- auto* cast = b.Bitcast(lve->Result(0)->Type(), builtin->Result(0));
- lve->Result(0)->ReplaceAllUsesWith(cast->Result(0));
+ return idx_value->Value()->ValueAs<uint32_t>() * store_ty->DeepestElement()->Size();
+ }
- builtin->InsertBefore(lve);
- cast->InsertBefore(lve);
- lve->Destroy();
- },
- [&](core::ir::StoreVectorElement* sve) {
- // Converts to:
- //
- // %1 = <sve->Value()>
- // %2:u32 = bitcast %1
- // %3:void = v.Store 0u, %2
+ // Creates the appropriate load instructions for the given result type.
+ core::ir::Call* MakeLoad(core::ir::Instruction* inst,
+ core::ir::Var* var,
+ const core::type::Type* result_ty,
+ core::ir::Value* offset) {
+ if (result_ty->is_numeric_scalar_or_vector()) {
+ return MakeScalarOrVectorLoad(var, result_ty, offset);
+ }
- auto* idx_value = sve->Index()->As<core::ir::Constant>();
- TINT_ASSERT(idx_value);
+ return tint::Switch(
+ result_ty, //
+ [&](const core::type::Struct* s) {
+ auto* fn = GetLoadFunctionFor(inst, var, s);
+ return b.Call(fn, offset);
+ },
+ [&](const core::type::Matrix* m) {
+ auto* fn = GetLoadFunctionFor(inst, var, m);
+ return b.Call(fn, offset);
+ }, //
+ [&](const core::type::Array* a) {
+ auto* fn = GetLoadFunctionFor(inst, var, a);
+ return b.Call(fn, offset);
+ }, //
+ TINT_ICE_ON_NO_MATCH);
+ }
- uint32_t pos = idx_value->Value()->ValueAs<uint32_t>() *
- var_ty->StoreType()->DeepestElement()->Size();
+    // Creates a `v.Load{,2,3,4} offset` call based on the provided type. The load returns a `u32`
+    // or vector of `u32`, which is then `bitcast` back to the desired type.
+ //
+ // This only works for `u32`, `i32`, `f16`, `f32` and the vector sizes of those types.
+ //
+    // The `f16` type is special in that it uses a templated load in HLSL (`Load<float16_t>`)
+    // which returns the correct type, so there is no bitcast.
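+    //
+    // For example (illustrative, matching the tests below), a `vec3<f32>` load becomes:
+    //   %1:vec3<u32> = %v.Load3 16u
+    //   %2:vec3<f32> = bitcast %1
+    // while a `vec4<f16>` load becomes `%3:vec4<f16> = %v.Load4F16 0u` with no bitcast.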
+ core::ir::Call* MakeScalarOrVectorLoad(core::ir::Var* var,
+ const core::type::Type* result_ty,
+ core::ir::Value* offset) {
+ bool is_f16 = result_ty->DeepestElement()->Is<core::type::F16>();
- auto* cast = b.Bitcast(ty.u32(), sve->Value());
- auto* builtin = b.MemberCall<hlsl::ir::MemberBuiltinCall>(
- ty.void_(), BuiltinFn::kStore, var, u32(pos), cast);
+ const core::type::Type* load_ty = ty.u32();
+ // An `f16` load returns an `f16` instead of a `u32`
+ if (is_f16) {
+ load_ty = ty.f16();
+ }
- cast->InsertBefore(sve);
- builtin->InsertBefore(sve);
- sve->Destroy();
- },
-
- [&](core::ir::Store*) {}, //
- [&](core::ir::Load*) {} //
- );
+ auto fn = is_f16 ? BuiltinFn::kLoadF16 : BuiltinFn::kLoad;
+ if (auto* v = result_ty->As<core::type::Vector>()) {
+ load_ty = ty.vec(load_ty, v->Width());
+ switch (v->Width()) {
+ case 2:
+ fn = is_f16 ? BuiltinFn::kLoad2F16 : BuiltinFn::kLoad2;
+ break;
+ case 3:
+ fn = is_f16 ? BuiltinFn::kLoad3F16 : BuiltinFn::kLoad3;
+ break;
+ case 4:
+ fn = is_f16 ? BuiltinFn::kLoad4F16 : BuiltinFn::kLoad4;
+ break;
+ default:
+ TINT_UNREACHABLE();
}
}
+
+ auto* builtin = b.MemberCall<hlsl::ir::MemberBuiltinCall>(load_ty, fn, var, offset);
+ core::ir::Call* res = nullptr;
+
+ // Do not bitcast the `f16` conversions as they need to be a templated Load instruction
+ if (is_f16) {
+ res = builtin;
+ } else {
+ res = b.Bitcast(result_ty, builtin->Result(0));
+ }
+ return res;
+ }
+
+ // Creates a load function for the given `var` and `struct` combination. Essentially creates a
+ // function similar to:
+ //
+ // fn custom_load_S(offset: u32) {
+ // let a = <load S member 0>(offset + member 0 offset);
+ // let b = <load S member 1>(offset + member 1 offset);
+ // ...
+ // let z = <load S member last>(offset + member last offset);
+ // return S(a, b, ..., z);
+ // }
+ core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
+ core::ir::Var* var,
+ const core::type::Struct* s) {
+ return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, s}, [&] {
+ auto* p = b.FunctionParam("offset", ty.u32());
+ auto* fn = b.Function(s);
+ fn->SetParams({p});
+
+ b.Append(fn->Block(), [&] {
+ Vector<core::ir::Value*, 4> values;
+ for (const auto* mem : s->Members()) {
+ values.Push(MakeLoad(inst, var, mem->Type(),
+ b.Add<u32>(p, u32(mem->Offset()))->Result(0))
+ ->Result(0));
+ }
+
+ b.Return(fn, b.Construct(s, values));
+ });
+
+ return fn;
+ });
+ }
+
+ // Creates a load function for the given `var` and `matrix` combination. Essentially creates a
+ // function similar to:
+ //
+ // fn custom_load_M(offset: u32) {
+    //   let a = <load M column 0>(offset + (0 * ColumnStride));
+    //   let b = <load M column 1>(offset + (1 * ColumnStride));
+    //   ...
+    //   let z = <load M column last>(offset + (last * ColumnStride));
+    //   return M(a, b, ..., z);
+ // }
+ core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
+ core::ir::Var* var,
+ const core::type::Matrix* mat) {
+ return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, mat}, [&] {
+ auto* p = b.FunctionParam("offset", ty.u32());
+ auto* fn = b.Function(mat);
+ fn->SetParams({p});
+
+ b.Append(fn->Block(), [&] {
+ Vector<core::ir::Value*, 4> values;
+ for (size_t i = 0; i < mat->columns(); ++i) {
+ auto* add = b.Add<u32>(p, u32(i * mat->ColumnStride()));
+ auto* load = MakeLoad(inst, var, mat->ColumnType(), add->Result(0));
+ values.Push(load->Result(0));
+ }
+
+ b.Return(fn, b.Construct(mat, values));
+ });
+
+ return fn;
+ });
+ }
+
+ // Creates a load function for the given `var` and `array` combination. Essentially creates a
+ // function similar to:
+ //
+ // fn custom_load_A(offset: u32) {
+ // A a = A();
+ // u32 i = 0;
+ // loop {
+ // if (i >= A length) {
+ // break;
+ // }
+ // a[i] = <load array type>(offset + (i * A->Stride()));
+ // i = i + 1;
+ // }
+ // return a;
+ // }
+ core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
+ core::ir::Var* var,
+ const core::type::Array* arr) {
+ return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, arr}, [&] {
+ auto* p = b.FunctionParam("offset", ty.u32());
+ auto* fn = b.Function(arr);
+ fn->SetParams({p});
+
+ b.Append(fn->Block(), [&] {
+ auto* result_arr = b.Var<function>("a", b.Zero(arr));
+
+ auto* count = arr->Count()->As<core::type::ConstantArrayCount>();
+ TINT_ASSERT(count);
+
+ b.LoopRange(ty, 0_u, u32(count->value), 1_u, [&](core::ir::Value* idx) {
+ auto* access = b.Access(ty.ptr<function>(arr->ElemType()), result_arr, idx);
+ auto* stride = b.Multiply<u32>(idx, u32(arr->Stride()));
+ auto* byte_offset = b.Add<u32>(p, stride);
+ b.Store(access, MakeLoad(inst, var, arr->ElemType(), byte_offset->Result(0)));
+ });
+
+ b.Return(fn, b.Load(result_arr));
+ });
+
+ return fn;
+ });
+ }
+
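+    // Replaces an `access` of `var` by walking the constant indices to compute the byte
+    // offset of the accessed element, then re-writing each supported use of the access
+    // result as a direct load at that offset.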
+ void Access(core::ir::Access* a, core::ir::Var* var, const core::type::Pointer*) {
+ const core::type::Type* obj = a->Object()->Type();
+ auto* view = obj->As<core::type::MemoryView>();
+ TINT_ASSERT(view);
+
+ obj = view->StoreType();
+
+ uint32_t byte_offset = 0;
+ for (auto* idx_value : a->Indices()) {
+ auto* cnst = idx_value->As<core::ir::Constant>();
+
+ // TODO(dsinclair): Handle non-constant accessors where the indices are dynamic
+ TINT_ASSERT(cnst);
+
+ uint32_t idx = cnst->Value()->ValueAs<uint32_t>();
+ tint::Switch(
+ obj, //
+ [&](const core::type::Vector* v) {
+ byte_offset += v->type()->Size() * idx;
+ obj = v->type();
+ },
+ [&](const core::type::Matrix* m) {
+ byte_offset += m->type()->Size() * m->rows() * idx;
+ obj = m->ColumnType();
+ },
+ [&](const core::type::Array* ary) {
+ byte_offset += ary->Stride() * idx;
+ obj = ary->ElemType();
+ },
+ [&](const core::type::Struct* s) {
+ auto* mem = s->Members()[idx];
+ byte_offset += mem->Offset();
+ obj = mem->Type();
+ },
+ TINT_ICE_ON_NO_MATCH);
+ }
+
+ auto insert_load = [&](core::ir::Instruction* inst, uint32_t offset) {
+ b.InsertBefore(inst, [&] {
+ auto* call = MakeLoad(inst, var, inst->Result(0)->Type(), b.Value(u32(offset)));
+ inst->Result(0)->ReplaceAllUsesWith(call->Result(0));
+ });
+ inst->Destroy();
+ };
+
+ // Copy the usages into a vector so we can remove items from the hashset.
+ auto usages = a->Result(0)->Usages().Vector();
+ for (auto& usage : usages) {
+ tint::Switch(
+ usage.instruction, //
+ [&](core::ir::Let*) {
+ // TODO(dsinclair): handle let
+ },
+ [&](core::ir::Access*) {
+ // TODO(dsinclair): Handle access
+ },
+
+ [&](core::ir::LoadVectorElement* lve) {
+ a->Result(0)->RemoveUsage(usage);
+
+ byte_offset += CalculateVectorIndex(lve->Index(), obj);
+ insert_load(lve, byte_offset);
+ }, //
+ [&](core::ir::Load* ld) {
+ a->Result(0)->RemoveUsage(usage);
+ insert_load(ld, byte_offset);
+ },
+
+ [&](core::ir::UserCall*) {
+ if (a->Result(0)->Type()->Is<core::type::Pointer>()) {
+                        // TODO(dsinclair): Passing a pointer into a function requires a
+                        // re-write: create a new call, copying the args but skipping the
+                        // pointer argument at usage.index.
+
+ } else {
+ TINT_UNREACHABLE();
+ }
+ }, //
+
+ [&](core::ir::StoreVectorElement*) {
+                    // TODO(dsinclair): Handle store vector elements
+ }, //
+ [&](core::ir::Store*) {
+ // TODO(dsinclair): Handle store
+ }, //
+
+ TINT_ICE_ON_NO_MATCH);
+ }
+
+ a->Destroy();
+ }
+
+ void Store(core::ir::Store*) {
+ // TODO(dsinclair): Handle store
+ }
+
+    // This should _only_ be handling loads directly from a `var`, as any loads through an
+    // `access` will already have been replaced when the `access` was converted.
+ void Load(core::ir::Load* ld, core::ir::Var* var) {
+ auto* result = ld->From()->As<core::ir::InstructionResult>();
+ TINT_ASSERT(result);
+
+ auto* inst = result->Instruction()->As<core::ir::Var>();
+ TINT_ASSERT(inst);
+
+ const core::type::Type* result_ty = inst->Result(0)->Type()->UnwrapPtr();
+
+ b.InsertBefore(ld, [&] {
+ auto* call = MakeLoad(ld, var, result_ty, b.Value(0_u));
+ ld->Result(0)->ReplaceAllUsesWith(call->Result(0));
+ });
+ ld->Destroy();
+ }
+
+ // Converts to:
+ //
+ // %1:u32 = v.Load 0u
+ // %b:f32 = bitcast %1
+ void LoadVectorElement(core::ir::LoadVectorElement* lve,
+ core::ir::Var* var,
+ const core::type::Pointer* var_ty) {
+ uint32_t pos = CalculateVectorIndex(lve->Index(), var_ty->StoreType());
+
+ b.InsertBefore(lve, [&] {
+ auto* result = MakeScalarOrVectorLoad(var, lve->Result(0)->Type(), b.Value(u32(pos)));
+ lve->Result(0)->ReplaceAllUsesWith(result->Result(0));
+ });
+
+ lve->Destroy();
+ }
+
+ // Converts to:
+ //
+ // %1 = <sve->Value()>
+ // %2:u32 = bitcast %1
+ // %3:void = v.Store 0u, %2
+ void StoreVectorElement(core::ir::StoreVectorElement* sve,
+ core::ir::Var* var,
+ const core::type::Pointer* var_ty) {
+ uint32_t pos = CalculateVectorIndex(sve->Index(), var_ty->StoreType());
+
+ auto* cast = b.Bitcast(ty.u32(), sve->Value());
+ auto* builtin = b.MemberCall<hlsl::ir::MemberBuiltinCall>(ty.void_(), BuiltinFn::kStore,
+ var, u32(pos), cast);
+
+ cast->InsertBefore(sve);
+ builtin->InsertBefore(sve);
+ sve->Destroy();
}
};
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.h b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.h
index 2a7b36d..d3d93f2 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.h
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.h
@@ -28,8 +28,6 @@
#ifndef SRC_TINT_LANG_HLSL_WRITER_RAISE_DECOMPOSE_MEMORY_ACCESS_H_
#define SRC_TINT_LANG_HLSL_WRITER_RAISE_DECOMPOSE_MEMORY_ACCESS_H_
-#include <string>
-
#include "src/tint/utils/result/result.h"
// Forward declarations.
@@ -39,7 +37,7 @@
namespace tint::hlsl::writer::raise {
-/// DecomposeMemoryAccess is a transform used to replace storage and uniform buffer accesses with a
+/// DecomposeMemoryAccess is a transform used to replace storage buffer accesses with a
/// combination of load, store or atomic functions on primitive types.
///
/// @param module the module to transform
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
index 940976d..da0b618 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
@@ -93,11 +93,11 @@
}
}
)";
- EXPECT_EQ(src, str());
+ ASSERT_EQ(src, str());
auto* expect = R"(
$B1: { # root
- %v:hlsl.byte_address_buffer<vec4<f32>, read> = var
+ %v:hlsl.byte_address_buffer<read> = var
}
%foo = @fragment func():void {
@@ -125,8 +125,8 @@
TEST_F(HlslWriterDecomposeMemoryAccessTest, VectorStore) {
auto* var = b.Var<storage, vec4<f32>, core::Access::kReadWrite>("v");
-
b.ir.root_block->Append(var);
+
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
b.StoreVectorElement(var, 0_u, 2_f);
@@ -151,11 +151,11 @@
}
}
)";
- EXPECT_EQ(src, str());
+ ASSERT_EQ(src, str());
auto* expect = R"(
$B1: { # root
- %v:hlsl.byte_address_buffer<vec4<f32>, read_write> = var
+ %v:hlsl.byte_address_buffer<read_write> = var
}
%foo = @fragment func():void {
@@ -177,5 +177,907 @@
EXPECT_EQ(expect, str());
}
+TEST_F(HlslWriterDecomposeMemoryAccessTest, DISABLED_AccessChainFromUnnamedAccessChain) {
+ auto* Inner =
+ ty.Struct(mod.symbols.New("Inner"),
+ {
+ {mod.symbols.New("c"), ty.f32(), core::type::StructMemberAttributes{}},
+ });
+ auto* sb = ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), Inner, core::type::StructMemberAttributes{}},
+ });
+
+ auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+ b.ir.root_block->Append(var);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* x = b.Access(ty.ptr(storage, sb, core::Access::kReadWrite), var);
+ auto* y = b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), x->Result(0), 1_u);
+ b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite),
+ y->Result(0), 0_u)));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+Inner = struct @align(4) {
+ c:f32 @offset(0)
+}
+
+SB = struct @align(4) {
+ a:i32 @offset(0)
+ b:Inner @offset(4)
+}
+
+$B1: { # root
+ %v:ptr<storage, SB, read_write> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:ptr<storage, SB, read_write> = access %v
+ %4:ptr<storage, Inner, read_write> = access %3, 1u
+ %5:ptr<storage, f32, read_write> = access %4, 0u
+ %6:f32 = load %5
+ %b:f32 = let %6
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+SB = struct @align(16) {
+ a:i32 @offset(0)
+ b:vec3<f32> @offset(16)
+}
+
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read_write> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:vec3<u32> = %v.Load3 16u
+ %a:vec3<f32> = bitcast %3
+ %b:f32 = %a 1u
+ ret
+ }
+}
+)";
+
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterDecomposeMemoryAccessTest, DISABLED_AccessChainFromLetAccessChain) {
+ auto* Inner =
+ ty.Struct(mod.symbols.New("Inner"),
+ {
+ {mod.symbols.New("c"), ty.f32(), core::type::StructMemberAttributes{}},
+ });
+ auto* sb = ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), Inner, core::type::StructMemberAttributes{}},
+ });
+
+ auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+ b.ir.root_block->Append(var);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Let("a", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), var, 1_u));
+ b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite),
+ a->Result(0), 0_u)));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+Inner = struct @align(4) {
+ c:f32 @offset(0)
+}
+
+SB = struct @align(4) {
+ a:i32 @offset(0)
+ b:Inner @offset(4)
+}
+
+$B1: { # root
+ %v:ptr<storage, SB, read_write> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:ptr<storage, Inner, read_write> = access %v, 1u
+ %a:ptr<storage, Inner, read_write> = let %3
+ %5:ptr<storage, f32, read_write> = access %a, 0u
+ %6:f32 = load %5
+ %b:f32 = let %6
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+SB = struct @align(16) {
+ a:i32 @offset(0)
+ b:vec3<f32> @offset(16)
+}
+
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read_write> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:vec3<u32> = %v.Load3 16u
+ %a:vec3<f32> = bitcast %3
+ %b:f32 = %a 1u
+ ret
+ }
+}
+)";
+
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessRwByteAddressBuffer) {
+ auto* sb =
+ ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), ty.vec3<f32>(), core::type::StructMemberAttributes{}},
+ });
+
+ auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+ b.ir.root_block->Append(var);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(b.Access(ty.ptr(storage, ty.i32(), core::Access::kReadWrite), var, 0_u)));
+ b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.vec3<f32>(), core::Access::kReadWrite), var,
+ 1_u)));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+SB = struct @align(16) {
+ a:i32 @offset(0)
+ b:vec3<f32> @offset(16)
+}
+
+$B1: { # root
+ %v:ptr<storage, SB, read_write> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:ptr<storage, i32, read_write> = access %v, 0u
+ %4:i32 = load %3
+ %a:i32 = let %4
+ %6:ptr<storage, vec3<f32>, read_write> = access %v, 1u
+ %7:vec3<f32> = load %6
+ %b:vec3<f32> = let %7
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+SB = struct @align(16) {
+ a:i32 @offset(0)
+ b:vec3<f32> @offset(16)
+}
+
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read_write> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:u32 = %v.Load 0u
+ %4:i32 = bitcast %3
+ %a:i32 = let %4
+ %6:vec3<u32> = %v.Load3 16u
+ %7:vec3<f32> = bitcast %6
+ %b:vec3<f32> = let %7
+ ret
+ }
+}
+)";
+
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessByteAddressBuffer) {
+ auto* sb = ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ });
+ auto* var = b.Var("v", storage, sb, core::Access::kRead);
+ b.ir.root_block->Append(var);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(b.Access(ty.ptr(storage, ty.i32(), core::Access::kRead), var, 0_u)));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+SB = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+$B1: { # root
+ %v:ptr<storage, SB, read> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:ptr<storage, i32, read> = access %v, 0u
+ %4:i32 = load %3
+ %a:i32 = let %4
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+SB = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read> = var
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:u32 = %v.Load 0u
+ %4:i32 = bitcast %3
+ %a:i32 = let %4
+ ret
+ }
+}
+)";
+
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessStorageVector) {
+ auto* var = b.Var<storage, vec4<f32>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.LoadVectorElement(var, 0_u));
+ b.Let("c", b.LoadVectorElement(var, 1_u));
+ b.Let("d", b.LoadVectorElement(var, 2_u));
+ b.Let("e", b.LoadVectorElement(var, 3_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %v:ptr<storage, vec4<f32>, read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:vec4<f32> = load %v
+ %a:vec4<f32> = let %3
+ %5:f32 = load_vector_element %v, 0u
+ %b:f32 = let %5
+ %7:f32 = load_vector_element %v, 1u
+ %c:f32 = let %7
+ %9:f32 = load_vector_element %v, 2u
+ %d:f32 = let %9
+ %11:f32 = load_vector_element %v, 3u
+ %e:f32 = let %11
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:vec4<u32> = %v.Load4 0u
+ %4:vec4<f32> = bitcast %3
+ %a:vec4<f32> = let %4
+ %6:u32 = %v.Load 0u
+ %7:f32 = bitcast %6
+ %b:f32 = let %7
+ %9:u32 = %v.Load 4u
+ %10:f32 = bitcast %9
+ %c:f32 = let %10
+ %12:u32 = %v.Load 8u
+ %13:f32 = bitcast %12
+ %d:f32 = let %13
+ %15:u32 = %v.Load 12u
+ %16:f32 = bitcast %15
+ %e:f32 = let %16
+ ret
+ }
+}
+)";
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
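+// f16 data uses the dedicated LoadF16/Load4F16 builtins, which return f16
+// values directly (no bitcast) and step by 2-byte elements (0, 2, 4, 6).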
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessStorageVectorF16) {
+ auto* var = b.Var<storage, vec4<f16>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.LoadVectorElement(var, 0_u));
+ b.Let("c", b.LoadVectorElement(var, 1_u));
+ b.Let("d", b.LoadVectorElement(var, 2_u));
+ b.Let("e", b.LoadVectorElement(var, 3_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %v:ptr<storage, vec4<f16>, read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:vec4<f16> = load %v
+ %a:vec4<f16> = let %3
+ %5:f16 = load_vector_element %v, 0u
+ %b:f16 = let %5
+ %7:f16 = load_vector_element %v, 1u
+ %c:f16 = let %7
+ %9:f16 = load_vector_element %v, 2u
+ %d:f16 = let %9
+ %11:f16 = load_vector_element %v, 3u
+ %e:f16 = let %11
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:vec4<f16> = %v.Load4F16 0u
+ %a:vec4<f16> = let %3
+ %5:f16 = %v.LoadF16 0u
+ %b:f16 = let %5
+ %7:f16 = %v.LoadF16 2u
+ %c:f16 = let %7
+ %9:f16 = %v.LoadF16 4u
+ %d:f16 = let %9
+ %11:f16 = %v.LoadF16 6u
+ %e:f16 = let %11
+ ret
+ }
+}
+)";
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
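+// Loading an entire matrix generates a helper function that Load4s each
+// column at its 16-byte stride and reconstructs the mat4x4; direct column
+// and element loads read straight from the computed byte offset.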
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessStorageMatrix) {
+ auto* var = b.Var<storage, mat4x4<f32>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<storage, vec4<f32>, core::Access::kRead>(), var, 3_u)));
+ b.Let("c", b.LoadVectorElement(
+ b.Access(ty.ptr<storage, vec4<f32>, core::Access::kRead>(), var, 1_u), 2_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %v:ptr<storage, mat4x4<f32>, read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:mat4x4<f32> = load %v
+ %a:mat4x4<f32> = let %3
+ %5:ptr<storage, vec4<f32>, read> = access %v, 3u
+ %6:vec4<f32> = load %5
+ %b:vec4<f32> = let %6
+ %8:ptr<storage, vec4<f32>, read> = access %v, 1u
+ %9:f32 = load_vector_element %8, 2u
+ %c:f32 = let %9
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:mat4x4<f32> = call %4, 0u
+ %a:mat4x4<f32> = let %3
+ %6:vec4<u32> = %v.Load4 48u
+ %7:vec4<f32> = bitcast %6
+ %b:vec4<f32> = let %7
+ %9:u32 = %v.Load 24u
+ %10:f32 = bitcast %9
+ %c:f32 = let %10
+ ret
+ }
+}
+%4 = func(%offset:u32):mat4x4<f32> {
+ $B3: {
+ %13:u32 = add %offset, 0u
+ %14:vec4<u32> = %v.Load4 %13
+ %15:vec4<f32> = bitcast %14
+ %16:u32 = add %offset, 16u
+ %17:vec4<u32> = %v.Load4 %16
+ %18:vec4<f32> = bitcast %17
+ %19:u32 = add %offset, 32u
+ %20:vec4<u32> = %v.Load4 %19
+ %21:vec4<f32> = bitcast %20
+ %22:u32 = add %offset, 48u
+ %23:vec4<u32> = %v.Load4 %22
+ %24:vec4<f32> = bitcast %23
+ %25:mat4x4<f32> = construct %15, %18, %21, %24
+ ret %25
+ }
+}
+)";
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
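+// Loading an entire fixed-size array generates a helper that loops over the
+// elements, loading each vec3 at offset + idx * 16 (the array stride) into a
+// function-local copy.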
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessStorageArray) {
+ auto* var = b.Var<storage, array<vec3<f32>, 5>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<storage, vec3<f32>, core::Access::kRead>(), var, 3_u)));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %v:ptr<storage, array<vec3<f32>, 5>, read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:array<vec3<f32>, 5> = load %v
+ %a:array<vec3<f32>, 5> = let %3
+ %5:ptr<storage, vec3<f32>, read> = access %v, 3u
+ %6:vec3<f32> = load %5
+ %b:vec3<f32> = let %6
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:array<vec3<f32>, 5> = call %4, 0u
+ %a:array<vec3<f32>, 5> = let %3
+ %6:vec3<u32> = %v.Load3 48u
+ %7:vec3<f32> = bitcast %6
+ %b:vec3<f32> = let %7
+ ret
+ }
+}
+%4 = func(%offset:u32):array<vec3<f32>, 5> {
+ $B3: {
+ %a_1:ptr<function, array<vec3<f32>, 5>, read_write> = var, array<vec3<f32>, 5>(vec3<f32>(0.0f)) # %a_1: 'a'
+ loop [i: $B4, b: $B5, c: $B6] { # loop_1
+ $B4: { # initializer
+ next_iteration 0u # -> $B5
+ }
+ $B5 (%idx:u32): { # body
+ %12:bool = gte %idx, 5u
+ if %12 [t: $B7] { # if_1
+ $B7: { # true
+ exit_loop # loop_1
+ }
+ }
+ %13:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
+ %14:u32 = mul %idx, 16u
+ %15:u32 = add %offset, %14
+ %16:vec3<u32> = %v.Load3 %15
+ %17:vec3<f32> = bitcast %16
+ store %13, %17
+ continue # -> $B6
+ }
+ $B6: { # continuing
+ %18:u32 = add %idx, 1u
+ next_iteration %18 # -> $B5
+ }
+ }
+ %19:array<vec3<f32>, 5> = load %a_1
+ ret %19
+ }
+}
+)";
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
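+// Same shape as the previous test but with 42 elements, checking that the
+// generated loop bound follows the array size rather than being hard-coded.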
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessStorageArrayWhichCanHaveSizesOtherThanFive) {
+ auto* var = b.Var<storage, array<vec3<f32>, 42>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<storage, vec3<f32>, core::Access::kRead>(), var, 3_u)));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %v:ptr<storage, array<vec3<f32>, 42>, read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:array<vec3<f32>, 42> = load %v
+ %a:array<vec3<f32>, 42> = let %3
+ %5:ptr<storage, vec3<f32>, read> = access %v, 3u
+ %6:vec3<f32> = load %5
+ %b:vec3<f32> = let %6
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:array<vec3<f32>, 42> = call %4, 0u
+ %a:array<vec3<f32>, 42> = let %3
+ %6:vec3<u32> = %v.Load3 48u
+ %7:vec3<f32> = bitcast %6
+ %b:vec3<f32> = let %7
+ ret
+ }
+}
+%4 = func(%offset:u32):array<vec3<f32>, 42> {
+ $B3: {
+ %a_1:ptr<function, array<vec3<f32>, 42>, read_write> = var, array<vec3<f32>, 42>(vec3<f32>(0.0f)) # %a_1: 'a'
+ loop [i: $B4, b: $B5, c: $B6] { # loop_1
+ $B4: { # initializer
+ next_iteration 0u # -> $B5
+ }
+ $B5 (%idx:u32): { # body
+ %12:bool = gte %idx, 42u
+ if %12 [t: $B7] { # if_1
+ $B7: { # true
+ exit_loop # loop_1
+ }
+ }
+ %13:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
+ %14:u32 = mul %idx, 16u
+ %15:u32 = add %offset, %14
+ %16:vec3<u32> = %v.Load3 %15
+ %17:vec3<f32> = bitcast %16
+ store %13, %17
+ continue # -> $B6
+ }
+ $B6: { # continuing
+ %18:u32 = add %idx, 1u
+ next_iteration %18 # -> $B5
+ }
+ }
+ %19:array<vec3<f32>, 42> = load %a_1
+ ret %19
+ }
+}
+)";
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
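+// Loading an entire struct generates a helper that loads each member at its
+// byte offset and rebuilds the value with a construct instruction.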
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessStorageStruct) {
+ auto* SB = ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), ty.f32(), core::type::StructMemberAttributes{}},
+ });
+
+ auto* var = b.Var("v", storage, SB, core::Access::kRead);
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<storage, f32, core::Access::kRead>(), var, 1_u)));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+SB = struct @align(4) {
+ a:i32 @offset(0)
+ b:f32 @offset(4)
+}
+
+$B1: { # root
+ %v:ptr<storage, SB, read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:SB = load %v
+ %a:SB = let %3
+ %5:ptr<storage, f32, read> = access %v, 1u
+ %6:f32 = load %5
+ %b:f32 = let %6
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+SB = struct @align(4) {
+ a:i32 @offset(0)
+ b:f32 @offset(4)
+}
+
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:SB = call %4, 0u
+ %a:SB = let %3
+ %6:u32 = %v.Load 4u
+ %7:f32 = bitcast %6
+ %b:f32 = let %7
+ ret
+ }
+}
+%4 = func(%offset:u32):SB {
+ $B3: {
+ %10:u32 = add %offset, 0u
+ %11:u32 = %v.Load %10
+ %12:i32 = bitcast %11
+ %13:u32 = add %offset, 4u
+ %14:u32 = %v.Load %13
+ %15:f32 = bitcast %14
+ %16:SB = construct %12, %15
+ ret %16
+ }
+}
+)";
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
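+// Nested composites get one load helper per composite type (SB, Outer,
+// Inner, mat3x3, the array), each taking a base byte offset. The direct
+// vector-element load folds the whole access chain to a single Load at 136u.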
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessStorageNested) {
+ auto* Inner = ty.Struct(
+ mod.symbols.New("Inner"),
+ {
+ {mod.symbols.New("s"), ty.mat3x3<f32>(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("t"), ty.array<vec3<f32>, 5>(), core::type::StructMemberAttributes{}},
+ });
+ auto* Outer =
+ ty.Struct(mod.symbols.New("Outer"),
+ {
+ {mod.symbols.New("x"), ty.f32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("y"), Inner, core::type::StructMemberAttributes{}},
+ });
+
+ auto* SB = ty.Struct(mod.symbols.New("SB"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), Outer, core::type::StructMemberAttributes{}},
+ });
+
+ auto* var = b.Var("v", storage, SB, core::Access::kRead);
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.LoadVectorElement(b.Access(ty.ptr<storage, vec3<f32>, core::Access::kRead>(),
+ var, 1_u, 1_u, 1_u, 3_u),
+ 2_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+Inner = struct @align(16) {
+ s:mat3x3<f32> @offset(0)
+ t:array<vec3<f32>, 5> @offset(48)
+}
+
+Outer = struct @align(16) {
+ x:f32 @offset(0)
+ y:Inner @offset(16)
+}
+
+SB = struct @align(16) {
+ a:i32 @offset(0)
+ b:Outer @offset(16)
+}
+
+$B1: { # root
+ %v:ptr<storage, SB, read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:SB = load %v
+ %a:SB = let %3
+ %5:ptr<storage, vec3<f32>, read> = access %v, 1u, 1u, 1u, 3u
+ %6:f32 = load_vector_element %5, 2u
+ %b:f32 = let %6
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+Inner = struct @align(16) {
+ s:mat3x3<f32> @offset(0)
+ t:array<vec3<f32>, 5> @offset(48)
+}
+
+Outer = struct @align(16) {
+ x:f32 @offset(0)
+ y:Inner @offset(16)
+}
+
+SB = struct @align(16) {
+ a:i32 @offset(0)
+ b:Outer @offset(16)
+}
+
+$B1: { # root
+ %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %3:SB = call %4, 0u
+ %a:SB = let %3
+ %6:u32 = %v.Load 136u
+ %7:f32 = bitcast %6
+ %b:f32 = let %7
+ ret
+ }
+}
+%4 = func(%offset:u32):SB {
+ $B3: {
+ %10:u32 = add %offset, 0u
+ %11:u32 = %v.Load %10
+ %12:i32 = bitcast %11
+ %13:u32 = add %offset, 16u
+ %14:Outer = call %15, %13
+ %16:SB = construct %12, %14
+ ret %16
+ }
+}
+%15 = func(%offset_1:u32):Outer { # %offset_1: 'offset'
+ $B4: {
+ %18:u32 = add %offset_1, 0u
+ %19:u32 = %v.Load %18
+ %20:f32 = bitcast %19
+ %21:u32 = add %offset_1, 16u
+ %22:Inner = call %23, %21
+ %24:Outer = construct %20, %22
+ ret %24
+ }
+}
+%23 = func(%offset_2:u32):Inner { # %offset_2: 'offset'
+ $B5: {
+ %26:u32 = add %offset_2, 0u
+ %27:mat3x3<f32> = call %28, %26
+ %29:u32 = add %offset_2, 48u
+ %30:array<vec3<f32>, 5> = call %31, %29
+ %32:Inner = construct %27, %30
+ ret %32
+ }
+}
+%28 = func(%offset_3:u32):mat3x3<f32> { # %offset_3: 'offset'
+ $B6: {
+ %34:u32 = add %offset_3, 0u
+ %35:vec3<u32> = %v.Load3 %34
+ %36:vec3<f32> = bitcast %35
+ %37:u32 = add %offset_3, 16u
+ %38:vec3<u32> = %v.Load3 %37
+ %39:vec3<f32> = bitcast %38
+ %40:u32 = add %offset_3, 32u
+ %41:vec3<u32> = %v.Load3 %40
+ %42:vec3<f32> = bitcast %41
+ %43:mat3x3<f32> = construct %36, %39, %42
+ ret %43
+ }
+}
+%31 = func(%offset_4:u32):array<vec3<f32>, 5> { # %offset_4: 'offset'
+ $B7: {
+ %a_1:ptr<function, array<vec3<f32>, 5>, read_write> = var, array<vec3<f32>, 5>(vec3<f32>(0.0f)) # %a_1: 'a'
+ loop [i: $B8, b: $B9, c: $B10] { # loop_1
+ $B8: { # initializer
+ next_iteration 0u # -> $B9
+ }
+ $B9 (%idx:u32): { # body
+ %47:bool = gte %idx, 5u
+ if %47 [t: $B11] { # if_1
+ $B11: { # true
+ exit_loop # loop_1
+ }
+ }
+ %48:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
+ %49:u32 = mul %idx, 16u
+ %50:u32 = add %offset_4, %49
+ %51:vec3<u32> = %v.Load3 %50
+ %52:vec3<f32> = bitcast %51
+ store %48, %52
+ continue # -> $B10
+ }
+ $B10: { # continuing
+ %53:u32 = add %idx, 1u
+ next_iteration %53 # -> $B9
+ }
+ }
+ %54:array<vec3<f32>, 5> = load %a_1
+ ret %54
+ }
+}
+)";
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::hlsl::writer::raise
diff --git a/src/tint/lang/hlsl/writer/raise/raise.cc b/src/tint/lang/hlsl/writer/raise/raise.cc
index e5f7b80..b75f3c7 100644
--- a/src/tint/lang/hlsl/writer/raise/raise.cc
+++ b/src/tint/lang/hlsl/writer/raise/raise.cc
@@ -133,10 +133,6 @@
RUN_TRANSFORM(core::ir::transform::AddEmptyEntryPoint, module);
- RUN_TRANSFORM(core::ir::transform::DirectVariableAccess, module,
- core::ir::transform::DirectVariableAccessOptions{});
- RUN_TRANSFORM(raise::DecomposeMemoryAccess, module);
-
if (options.compiler == Options::Compiler::kFXC) {
RUN_TRANSFORM(raise::FxcPolyfill, module);
}
@@ -156,6 +152,12 @@
RUN_TRANSFORM(core::ir::transform::Robustness, module, config);
}
+
+ RUN_TRANSFORM(core::ir::transform::DirectVariableAccess, module,
+ core::ir::transform::DirectVariableAccessOptions{});
+    // DecomposeMemoryAccess must come after Robustness and DirectVariableAccess
+    // so that it only sees clamped access chains made directly on the vars.
+ RUN_TRANSFORM(raise::DecomposeMemoryAccess, module);
+
if (!options.disable_workgroup_init) {
RUN_TRANSFORM(core::ir::transform::ZeroInitWorkgroupMemory, module);
}