[hlsl] Add dynamic buffer accesses to Decompose Memory Access.
This CL adds support for dynamic (non-constant) buffer access
indices to the decompose memory access transform.
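
The transform now gathers each access chain into an OffsetData value: a
constant byte offset plus a list of dynamic u32 index expressions, which
are materialized with convert/mul/add instructions when the final offset
is needed. arrayLength calls on runtime arrays are lowered to the new
GetDimensions member builtin. A sketch of the emitted HLSL, adapted from
the new AccessComplexDynamicAccessChainSplit test expectation (variable
names simplified and comments added here for illustration):

  RWByteAddressBuffer sb : register(u0);
  void foo() {
    uint j = 1u;
    uint len = 0u;
    sb.GetDimensions(len);                          // buffer size in bytes
    uint i = min(4u, (((len - 16u) / 128u) - 1u));  // clamped runtime-array index
    uint k = min(j, 2u);                            // clamped fixed-array index
    float x = asfloat(sb.Load(((56u + (uint(i) * 128u)) + (uint(k) * 32u))));
  }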
Bug: 349867642
Change-Id: I256116f63a07a30187529a6be4b19abaad82355e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196974
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/hlsl/builtin_fn.cc b/src/tint/lang/hlsl/builtin_fn.cc
index 3bfa425..9fb3784 100644
--- a/src/tint/lang/hlsl/builtin_fn.cc
+++ b/src/tint/lang/hlsl/builtin_fn.cc
@@ -52,6 +52,8 @@
return "f32tof16";
case BuiltinFn::kF16Tof32:
return "f16tof32";
+ case BuiltinFn::kMul:
+ return "mul";
case BuiltinFn::kLoad:
return "Load";
case BuiltinFn::kLoad2:
@@ -76,8 +78,8 @@
return "Store3";
case BuiltinFn::kStore4:
return "Store4";
- case BuiltinFn::kMul:
- return "mul";
+ case BuiltinFn::kGetDimensions:
+ return "GetDimensions";
}
return "<unknown>";
}
diff --git a/src/tint/lang/hlsl/builtin_fn.h b/src/tint/lang/hlsl/builtin_fn.h
index ab8b68a..fb6c0b1 100644
--- a/src/tint/lang/hlsl/builtin_fn.h
+++ b/src/tint/lang/hlsl/builtin_fn.h
@@ -52,6 +52,7 @@
kAsfloat,
kF32Tof16,
kF16Tof32,
+ kMul,
kLoad,
kLoad2,
kLoad3,
@@ -64,7 +65,7 @@
kStore2,
kStore3,
kStore4,
- kMul,
+ kGetDimensions,
kNone,
};
diff --git a/src/tint/lang/hlsl/hlsl.def b/src/tint/lang/hlsl/hlsl.def
index b226e1d..627646d 100644
--- a/src/tint/lang/hlsl/hlsl.def
+++ b/src/tint/lang/hlsl/hlsl.def
@@ -100,6 +100,10 @@
fn f16tof32(u32) -> f32
fn f16tof32[N: num](vec<N, u32>) -> vec<N, f32>
+fn mul [T: f32_f16, C: num, R: num](mat<C, R, T>, vec<C, T>) -> vec<R, T>
+fn mul [T: f32_f16, C: num, R: num](vec<R, T>, mat<C, R, T>) -> vec<C, T>
+fn mul [T: f32_f16, K: num, C: num, R: num](mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
+
@member_function fn Load(byte_address_buffer<readable>, offset: u32) -> u32
@member_function fn Load2(byte_address_buffer<readable>, offset: u32) -> vec2<u32>
@member_function fn Load3(byte_address_buffer<readable>, offset: u32) -> vec3<u32>
@@ -115,6 +119,5 @@
@member_function fn Store3(byte_address_buffer<writable>, offset: u32, value: vec3<u32>)
@member_function fn Store4(byte_address_buffer<writable>, offset: u32, value: vec4<u32>)
-fn mul [T: f32_f16, C: num, R: num](mat<C, R, T>, vec<C, T>) -> vec<R, T>
-fn mul [T: f32_f16, C: num, R: num](vec<R, T>, mat<C, R, T>) -> vec<C, T>
-fn mul [T: f32_f16, K: num, C: num, R: num](mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
+@member_function fn GetDimensions[A: access](byte_address_buffer<A>, u32)
+
diff --git a/src/tint/lang/hlsl/intrinsic/data.cc b/src/tint/lang/hlsl/intrinsic/data.cc
index 8db2ea0..dc2b89e 100644
--- a/src/tint/lang/hlsl/intrinsic/data.cc
+++ b/src/tint/lang/hlsl/intrinsic/data.cc
@@ -693,10 +693,12 @@
/* [50] */ MatcherIndex(7),
/* [51] */ MatcherIndex(23),
/* [52] */ MatcherIndex(6),
- /* [53] */ MatcherIndex(25),
- /* [54] */ MatcherIndex(26),
- /* [55] */ MatcherIndex(24),
- /* [56] */ MatcherIndex(27),
+ /* [53] */ MatcherIndex(23),
+ /* [54] */ MatcherIndex(0),
+ /* [55] */ MatcherIndex(25),
+ /* [56] */ MatcherIndex(26),
+ /* [57] */ MatcherIndex(24),
+ /* [58] */ MatcherIndex(27),
};
static_assert(MatcherIndicesIndex::CanIndex(kMatcherIndices),
@@ -766,66 +768,71 @@
{
/* [12] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(37),
- },
- {
- /* [13] */
- /* usage */ core::ParameterUsage::kOffset,
- /* matcher_indices */ MatcherIndicesIndex(24),
- },
- {
- /* [14] */
- /* usage */ core::ParameterUsage::kNone,
/* matcher_indices */ MatcherIndicesIndex(0),
},
{
- /* [15] */
+ /* [13] */
/* usage */ core::ParameterUsage::kNone,
/* matcher_indices */ MatcherIndicesIndex(19),
},
{
- /* [16] */
+ /* [14] */
/* usage */ core::ParameterUsage::kNone,
/* matcher_indices */ MatcherIndicesIndex(34),
},
{
- /* [17] */
+ /* [15] */
/* usage */ core::ParameterUsage::kNone,
/* matcher_indices */ MatcherIndicesIndex(0),
},
{
- /* [18] */
+ /* [16] */
/* usage */ core::ParameterUsage::kNone,
/* matcher_indices */ MatcherIndicesIndex(8),
},
{
- /* [19] */
+ /* [17] */
/* usage */ core::ParameterUsage::kNone,
/* matcher_indices */ MatcherIndicesIndex(12),
},
{
+ /* [18] */
+ /* usage */ core::ParameterUsage::kNone,
+ /* matcher_indices */ MatcherIndicesIndex(37),
+ },
+ {
+ /* [19] */
+ /* usage */ core::ParameterUsage::kOffset,
+ /* matcher_indices */ MatcherIndicesIndex(24),
+ },
+ {
/* [20] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(3),
+ /* matcher_indices */ MatcherIndicesIndex(53),
},
{
/* [21] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(27),
+ /* matcher_indices */ MatcherIndicesIndex(24),
},
{
/* [22] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(31),
+ /* matcher_indices */ MatcherIndicesIndex(3),
},
{
/* [23] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(24),
+ /* matcher_indices */ MatcherIndicesIndex(27),
},
{
/* [24] */
/* usage */ core::ParameterUsage::kNone,
+ /* matcher_indices */ MatcherIndicesIndex(31),
+ },
+ {
+ /* [25] */
+ /* usage */ core::ParameterUsage::kNone,
/* matcher_indices */ MatcherIndicesIndex(28),
},
};
@@ -837,7 +844,7 @@
{
/* [0] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(56),
+ /* matcher_indices */ MatcherIndicesIndex(58),
/* kind */ TemplateInfo::Kind::kType,
},
{
@@ -861,7 +868,7 @@
{
/* [4] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(56),
+ /* matcher_indices */ MatcherIndicesIndex(58),
/* kind */ TemplateInfo::Kind::kType,
},
{
@@ -879,7 +886,7 @@
{
/* [7] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(53),
+ /* matcher_indices */ MatcherIndicesIndex(55),
/* kind */ TemplateInfo::Kind::kType,
},
{
@@ -891,7 +898,7 @@
{
/* [9] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(54),
+ /* matcher_indices */ MatcherIndicesIndex(56),
/* kind */ TemplateInfo::Kind::kType,
},
{
@@ -903,7 +910,7 @@
{
/* [11] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(55),
+ /* matcher_indices */ MatcherIndicesIndex(57),
/* kind */ TemplateInfo::Kind::kType,
},
{
@@ -912,6 +919,12 @@
/* matcher_indices */ MatcherIndicesIndex(/* invalid */),
/* kind */ TemplateInfo::Kind::kNumber,
},
+ {
+ /* [13] */
+ /* name */ "A",
+ /* matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* kind */ TemplateInfo::Kind::kNumber,
+ },
};
static_assert(TemplateIndex::CanIndex(kTemplates),
@@ -925,7 +938,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 3,
/* templates */ TemplateIndex(4),
- /* parameters */ ParameterIndex(14),
+ /* parameters */ ParameterIndex(12),
/* return_matcher_indices */ MatcherIndicesIndex(34),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -936,7 +949,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 3,
/* templates */ TemplateIndex(4),
- /* parameters */ ParameterIndex(16),
+ /* parameters */ ParameterIndex(14),
/* return_matcher_indices */ MatcherIndicesIndex(19),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -947,7 +960,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 4,
/* templates */ TemplateIndex(0),
- /* parameters */ ParameterIndex(18),
+ /* parameters */ ParameterIndex(16),
/* return_matcher_indices */ MatcherIndicesIndex(4),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -958,7 +971,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 1,
/* templates */ TemplateIndex(7),
- /* parameters */ ParameterIndex(20),
+ /* parameters */ ParameterIndex(22),
/* return_matcher_indices */ MatcherIndicesIndex(18),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -969,7 +982,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 2,
/* templates */ TemplateIndex(7),
- /* parameters */ ParameterIndex(15),
+ /* parameters */ ParameterIndex(13),
/* return_matcher_indices */ MatcherIndicesIndex(16),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -980,7 +993,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 1,
/* templates */ TemplateIndex(9),
- /* parameters */ ParameterIndex(20),
+ /* parameters */ ParameterIndex(22),
/* return_matcher_indices */ MatcherIndicesIndex(24),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -991,7 +1004,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 2,
/* templates */ TemplateIndex(9),
- /* parameters */ ParameterIndex(15),
+ /* parameters */ ParameterIndex(13),
/* return_matcher_indices */ MatcherIndicesIndex(22),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1002,7 +1015,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 1,
/* templates */ TemplateIndex(11),
- /* parameters */ ParameterIndex(20),
+ /* parameters */ ParameterIndex(22),
/* return_matcher_indices */ MatcherIndicesIndex(27),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1013,7 +1026,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 2,
/* templates */ TemplateIndex(11),
- /* parameters */ ParameterIndex(15),
+ /* parameters */ ParameterIndex(13),
/* return_matcher_indices */ MatcherIndicesIndex(25),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1024,7 +1037,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(21),
+ /* parameters */ ParameterIndex(23),
/* return_matcher_indices */ MatcherIndicesIndex(24),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1035,7 +1048,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 1,
/* templates */ TemplateIndex(8),
- /* parameters */ ParameterIndex(22),
+ /* parameters */ ParameterIndex(24),
/* return_matcher_indices */ MatcherIndicesIndex(28),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1046,7 +1059,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(23),
+ /* parameters */ ParameterIndex(21),
/* return_matcher_indices */ MatcherIndicesIndex(27),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1057,7 +1070,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 1,
/* templates */ TemplateIndex(8),
- /* parameters */ ParameterIndex(24),
+ /* parameters */ ParameterIndex(25),
/* return_matcher_indices */ MatcherIndicesIndex(31),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1068,7 +1081,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
+ /* parameters */ ParameterIndex(18),
/* return_matcher_indices */ MatcherIndicesIndex(24),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1079,7 +1092,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
+ /* parameters */ ParameterIndex(18),
/* return_matcher_indices */ MatcherIndicesIndex(39),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1090,7 +1103,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
+ /* parameters */ ParameterIndex(18),
/* return_matcher_indices */ MatcherIndicesIndex(41),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1101,7 +1114,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
+ /* parameters */ ParameterIndex(18),
/* return_matcher_indices */ MatcherIndicesIndex(43),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1112,7 +1125,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
+ /* parameters */ ParameterIndex(18),
/* return_matcher_indices */ MatcherIndicesIndex(46),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1123,7 +1136,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
+ /* parameters */ ParameterIndex(18),
/* return_matcher_indices */ MatcherIndicesIndex(45),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1134,7 +1147,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
+ /* parameters */ ParameterIndex(18),
/* return_matcher_indices */ MatcherIndicesIndex(47),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1145,7 +1158,7 @@
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
+ /* parameters */ ParameterIndex(18),
/* return_matcher_indices */ MatcherIndicesIndex(49),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1193,6 +1206,17 @@
/* return_matcher_indices */ MatcherIndicesIndex(/* invalid */),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
+ {
+ /* [25] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
+ /* num_parameters */ 2,
+ /* num_explicit_templates */ 0,
+ /* num_templates */ 1,
+ /* templates */ TemplateIndex(13),
+ /* parameters */ ParameterIndex(20),
+ /* return_matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
};
static_assert(OverloadIndex::CanIndex(kOverloads),
@@ -1236,84 +1260,90 @@
},
{
/* [5] */
- /* fn Load(byte_address_buffer<readable>, offset: u32) -> u32 */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(13),
- },
- {
- /* [6] */
- /* fn Load2(byte_address_buffer<readable>, offset: u32) -> vec2<u32> */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(14),
- },
- {
- /* [7] */
- /* fn Load3(byte_address_buffer<readable>, offset: u32) -> vec3<u32> */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(15),
- },
- {
- /* [8] */
- /* fn Load4(byte_address_buffer<readable>, offset: u32) -> vec4<u32> */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(16),
- },
- {
- /* [9] */
- /* fn LoadF16(byte_address_buffer<readable>, offset: u32) -> f16 */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(17),
- },
- {
- /* [10] */
- /* fn Load2F16(byte_address_buffer<readable>, offset: u32) -> vec2<f16> */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(18),
- },
- {
- /* [11] */
- /* fn Load3F16(byte_address_buffer<readable>, offset: u32) -> vec3<f16> */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(19),
- },
- {
- /* [12] */
- /* fn Load4F16(byte_address_buffer<readable>, offset: u32) -> vec4<f16> */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(20),
- },
- {
- /* [13] */
- /* fn Store(byte_address_buffer<writable>, offset: u32, value: u32) */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(21),
- },
- {
- /* [14] */
- /* fn Store2(byte_address_buffer<writable>, offset: u32, value: vec2<u32>) */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(22),
- },
- {
- /* [15] */
- /* fn Store3(byte_address_buffer<writable>, offset: u32, value: vec3<u32>) */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(23),
- },
- {
- /* [16] */
- /* fn Store4(byte_address_buffer<writable>, offset: u32, value: vec4<u32>) */
- /* num overloads */ 1,
- /* overloads */ OverloadIndex(24),
- },
- {
- /* [17] */
/* fn mul[T : f32_f16, C : num, R : num](mat<C, R, T>, vec<C, T>) -> vec<R, T> */
/* fn mul[T : f32_f16, C : num, R : num](vec<R, T>, mat<C, R, T>) -> vec<C, T> */
/* fn mul[T : f32_f16, K : num, C : num, R : num](mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
/* num overloads */ 3,
/* overloads */ OverloadIndex(0),
},
+ {
+ /* [6] */
+ /* fn Load(byte_address_buffer<readable>, offset: u32) -> u32 */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(13),
+ },
+ {
+ /* [7] */
+ /* fn Load2(byte_address_buffer<readable>, offset: u32) -> vec2<u32> */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(14),
+ },
+ {
+ /* [8] */
+ /* fn Load3(byte_address_buffer<readable>, offset: u32) -> vec3<u32> */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(15),
+ },
+ {
+ /* [9] */
+ /* fn Load4(byte_address_buffer<readable>, offset: u32) -> vec4<u32> */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(16),
+ },
+ {
+ /* [10] */
+ /* fn LoadF16(byte_address_buffer<readable>, offset: u32) -> f16 */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(17),
+ },
+ {
+ /* [11] */
+ /* fn Load2F16(byte_address_buffer<readable>, offset: u32) -> vec2<f16> */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(18),
+ },
+ {
+ /* [12] */
+ /* fn Load3F16(byte_address_buffer<readable>, offset: u32) -> vec3<f16> */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(19),
+ },
+ {
+ /* [13] */
+ /* fn Load4F16(byte_address_buffer<readable>, offset: u32) -> vec4<f16> */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(20),
+ },
+ {
+ /* [14] */
+ /* fn Store(byte_address_buffer<writable>, offset: u32, value: u32) */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(21),
+ },
+ {
+ /* [15] */
+ /* fn Store2(byte_address_buffer<writable>, offset: u32, value: vec2<u32>) */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(22),
+ },
+ {
+ /* [16] */
+ /* fn Store3(byte_address_buffer<writable>, offset: u32, value: vec3<u32>) */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(23),
+ },
+ {
+ /* [17] */
+ /* fn Store4(byte_address_buffer<writable>, offset: u32, value: vec4<u32>) */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(24),
+ },
+ {
+ /* [18] */
+ /* fn GetDimensions[A : access](byte_address_buffer<A>, u32) */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(25),
+ },
};
// clang-format on
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index a47d139..cc47037 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -686,7 +686,7 @@
)");
}
-TEST_F(HlslWriterTest, DISABLED_AccessComplexDynamicAccessChain) {
+TEST_F(HlslWriterTest, AccessComplexDynamicAccessChain) {
auto* S1 = ty.Struct(mod.symbols.New("S1"), {
{mod.symbols.New("a"), ty.i32()},
{mod.symbols.New("b"), ty.vec3<f32>()},
@@ -721,13 +721,70 @@
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(
-ByteAddressBuffer v : register(t0);
-void m() {
+RWByteAddressBuffer sb : register(u0);
+void foo() {
int i = 4;
+ int v = i;
uint j = 1u;
+ uint v_1 = j;
int k = 2;
- float x = asfloat(v.Load((((((16u + (128u * uint(i))) + 16u) + (32u * j)) + 16u) + (4u * uint(k)))));
+ int v_2 = k;
+ uint v_3 = 0u;
+ sb.GetDimensions(v_3);
+ uint v_4 = min(uint(v), (((v_3 - 16u) / 128u) - 1u));
+ uint v_5 = min(v_1, 2u);
+ uint v_6 = (uint(v_4) * 128u);
+ uint v_7 = (uint(v_5) * 32u);
+ float x = asfloat(sb.Load((((48u + v_6) + v_7) + (uint(min(uint(v_2), 2u)) * 4u))));
}
+
+)");
+}
+
+TEST_F(HlslWriterTest, AccessComplexDynamicAccessChainSplit) {
+ auto* S1 = ty.Struct(mod.symbols.New("S1"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.vec3<f32>()},
+ {mod.symbols.New("c"), ty.i32()},
+ });
+ auto* S2 = ty.Struct(mod.symbols.New("S2"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.array(S1, 3)},
+ {mod.symbols.New("c"), ty.i32()},
+ });
+
+ auto* SB = ty.Struct(mod.symbols.New("SB"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.runtime_array(S2)},
+ });
+
+ auto* var = b.Var("sb", storage, SB, core::Access::kReadWrite);
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* j = b.Load(b.Var("j", 1_u));
+ b.Let("x", b.LoadVectorElement(b.Access(ty.ptr<storage, vec3<f32>, read_write>(), var, 1_u,
+ 4_u, 1_u, j, 1_u),
+ 2_u));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+RWByteAddressBuffer sb : register(u0);
+void foo() {
+ uint j = 1u;
+ uint v = j;
+ uint v_1 = 0u;
+ sb.GetDimensions(v_1);
+ uint v_2 = min(4u, (((v_1 - 16u) / 128u) - 1u));
+ uint v_3 = min(v, 2u);
+ uint v_4 = (uint(v_2) * 128u);
+ float x = asfloat(sb.Load(((56u + v_4) + (uint(v_3) * 32u))));
+}
+
)");
}
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
index f7d2e86..58ea58b 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
@@ -111,7 +111,10 @@
[&](core::ir::StoreVectorElement* s) { StoreVectorElement(s, var, var_ty); },
[&](core::ir::Store* s) { Store(s); }, //
[&](core::ir::Load* l) { Load(l, var); }, //
- [&](core::ir::Access* a) { Access(a, var, a->Object()->Type(), 0u); },
+ [&](core::ir::Access* a) {
+ OffsetData offset;
+ Access(a, var, a->Object()->Type(), &offset);
+ },
[&](core::ir::Let* let) {
// The `let` is, essentially, an alias for the `var` as it's assigned
// directly. Gather all the `let` usages into our worklist, and then replace
@@ -130,13 +133,32 @@
}
}
- uint32_t CalculateVectorIndex(core::ir::Value* v, const core::type::Type* store_ty) {
- auto* idx_value = v->As<core::ir::Constant>();
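+ // An access-chain offset accumulated by UpdateOffsetData: a constant
+ // byte offset plus any dynamic index expressions (each already scaled
+ // to a byte offset). OffsetToValue() folds them into a single u32.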
+ struct OffsetData {
+ uint32_t byte_offset = 0;
+ Vector<core::ir::Value*, 4> expr{};
+ };
- // TODO(dsinclair): Handle non-constant vector indices.
- TINT_ASSERT(idx_value);
+ // Note, must be called inside a builder insert block (Append, InsertBefore, etc)
+ void UpdateOffsetData(core::ir::Value* v, uint32_t elm_size, OffsetData* offset) {
+ tint::Switch(
+ v, //
+ [&](core::ir::Constant* idx_value) {
+ offset->byte_offset += idx_value->Value()->ValueAs<uint32_t>() * elm_size;
+ },
+ [&](core::ir::Value* val) {
+ offset->expr.Push(
+ b.Multiply(ty.u32(), b.Convert(ty.u32(), val), u32(elm_size))->Result(0));
+ },
+ TINT_ICE_ON_NO_MATCH);
+ }
- return idx_value->Value()->ValueAs<uint32_t>() * store_ty->DeepestElement()->Size();
+ // Note, must be called inside a builder insert block (Append, InsertBefore, etc)
+ core::ir::Value* OffsetToValue(OffsetData offset) {
+ core::ir::Value* val = b.Value(u32(offset.byte_offset));
+ for (core::ir::Value* expr : offset.expr) {
+ val = b.Add(ty.u32(), val, expr)->Result(0);
+ }
+ return val;
}
// Creates the appropriate load instructions for the given result type.
@@ -165,13 +187,13 @@
TINT_ICE_ON_NO_MATCH);
}
- // Creates a `v.Load{2,3,4} offset` call based on the provided type. The load returns a `u32` or
- // vector of `u32` and then a `bitcast` is done to get back to the desired type.
+ // Creates a `v.Load{2,3,4} offset` call based on the provided type. The load returns a
+ // `u32` or vector of `u32` and then a `bitcast` is done to get back to the desired type.
//
// This only works for `u32`, `i32`, `f16`, `f32` and the vector sizes of those types.
//
- // The `f16` type is special in that `f16` uses a templated load in HLSL `Load<float16_t>` and
- // returns the correct type, so there is no bitcast.
+ // The `f16` type is special in that `f16` uses a templated load in HLSL `Load<float16_t>`
+ // and returns the correct type, so there is no bitcast.
core::ir::Call* MakeScalarOrVectorLoad(core::ir::Var* var,
const core::type::Type* result_ty,
core::ir::Value* offset) {
@@ -213,8 +235,8 @@
return res;
}
- // Creates a load function for the given `var` and `struct` combination. Essentially creates a
- // function similar to:
+ // Creates a load function for the given `var` and `struct` combination. Essentially creates
+ // a function similar to:
//
// fn custom_load_S(offset: u32) {
// let a = <load S member 0>(offset + member 0 offset);
@@ -246,8 +268,8 @@
});
}
- // Creates a load function for the given `var` and `matrix` combination. Essentially creates a
- // function similar to:
+ // Creates a load function for the given `var` and `matrix` combination. Essentially creates
+ // a function similar to:
//
// fn custom_load_M(offset: u32) {
// let a = <load M column 1>(offset + (1 * ColumnStride));
@@ -279,8 +301,8 @@
});
}
- // Creates a load function for the given `var` and `array` combination. Essentially creates a
- // function similar to:
+ // Creates a load function for the given `var` and `array` combination. Essentially creates
+ // a function similar to:
//
// fn custom_load_A(offset: u32) {
// A a = A();
@@ -322,10 +344,10 @@
});
}
- void InsertLoad(core::ir::Var* var, core::ir::Instruction* inst, uint32_t offset) {
+ void InsertLoad(core::ir::Var* var, core::ir::Instruction* inst, OffsetData offset) {
b.InsertBefore(inst, [&] {
auto* call =
- MakeLoad(inst, var, inst->Result(0)->Type()->UnwrapPtr(), b.Value(u32(offset)));
+ MakeLoad(inst, var, inst->Result(0)->Type()->UnwrapPtr(), OffsetToValue(offset));
inst->Result(0)->ReplaceAllUsesWith(call->Result(0));
});
inst->Destroy();
@@ -334,7 +356,9 @@
void Access(core::ir::Access* a,
core::ir::Var* var,
const core::type::Type* obj,
- uint32_t byte_offset) {
+ OffsetData* offset) {
+ TINT_ASSERT(offset);
+
// Note, because we recurse through the `access` helper, the object passed in isn't
// necessarily the originating `var` object, but maybe a partially resolved access chain
// object.
@@ -343,29 +367,31 @@
}
for (auto* idx_value : a->Indices()) {
- auto* cnst = idx_value->As<core::ir::Constant>();
-
- // TODO(dsinclair): Handle non-constant accessors where the indices are dynamic
- TINT_ASSERT(cnst);
-
- uint32_t idx = cnst->Value()->ValueAs<uint32_t>();
tint::Switch(
obj, //
[&](const core::type::Vector* v) {
- byte_offset += v->type()->Size() * idx;
+ b.InsertBefore(a,
+ [&] { UpdateOffsetData(idx_value, v->type()->Size(), offset); });
obj = v->type();
},
[&](const core::type::Matrix* m) {
- byte_offset += m->type()->Size() * m->rows() * idx;
+ b.InsertBefore(a,
+ [&] { UpdateOffsetData(idx_value, m->ColumnStride(), offset); });
obj = m->ColumnType();
},
[&](const core::type::Array* ary) {
- byte_offset += ary->Stride() * idx;
+ b.InsertBefore(a, [&] { UpdateOffsetData(idx_value, ary->Stride(), offset); });
obj = ary->ElemType();
},
[&](const core::type::Struct* s) {
+ auto* cnst = idx_value->As<core::ir::Constant>();
+
+ // A struct index must be a constant
+ TINT_ASSERT(cnst);
+
+ uint32_t idx = cnst->Value()->ValueAs<uint32_t>();
auto* mem = s->Members()[idx];
- byte_offset += mem->Offset();
+ offset->byte_offset += mem->Offset();
obj = mem->Type();
},
TINT_ICE_ON_NO_MATCH);
@@ -378,8 +404,9 @@
tint::Switch(
usage.instruction,
[&](core::ir::Let* let) {
- // The `let` is essentially an alias to the `access`. So, add the `let` usages
- // into the usage worklist, and replace the let with the access chain directly.
+ // The `let` is essentially an alias to the `access`. So, add the `let`
+ // usages into the usage worklist, and replace the let with the access chain
+ // directly.
for (auto& u : let->Result(0)->Usages()) {
usages.Push(u);
}
@@ -388,28 +415,58 @@
},
[&](core::ir::Access* sub_access) {
// Treat an access chain of the access chain as a continuation of the outer
- // chain. Pass through the object we stopped at and the current byte_offset and
- // then restart the access chain replacement for the new access chain.
- Access(sub_access, var, obj, byte_offset);
+ // chain. Pass through the object we stopped at and the current byte_offset
+ // and then restart the access chain replacement for the new access chain.
+ Access(sub_access, var, obj, offset);
},
[&](core::ir::LoadVectorElement* lve) {
a->Result(0)->RemoveUsage(usage);
- byte_offset += CalculateVectorIndex(lve->Index(), obj);
- InsertLoad(var, lve, byte_offset);
+ b.InsertBefore(lve, [&] {
+ UpdateOffsetData(lve->Index(), obj->DeepestElement()->Size(), offset);
+ });
+ InsertLoad(var, lve, *offset);
},
[&](core::ir::Load* ld) {
a->Result(0)->RemoveUsage(usage);
- InsertLoad(var, ld, byte_offset);
+ InsertLoad(var, ld, *offset);
},
[&](core::ir::StoreVectorElement*) {
- // TODO(dsinclair): Handle stor vector elements
+ // TODO(dsinclair): Handle store vector elements
}, //
[&](core::ir::Store*) {
// TODO(dsinclair): Handle store
}, //
+ [&](core::ir::CoreBuiltinCall* call) {
+ // Only an `arrayLength` builtin call is expected to use the access chain.
+ TINT_ASSERT(call->Func() == core::BuiltinFn::kArrayLength);
+
+ // If this access chain is being used in an `arrayLength` call then the access
+ // chain _must_ have resolved to the runtime array member of the structure. So,
+ // we _must_ have set `obj` to the array member which is a runtime array.
+ auto* arr_ty = obj->As<core::type::Array>();
+ TINT_ASSERT(arr_ty && arr_ty->Count()->Is<core::type::RuntimeArrayCount>());
+
+ b.InsertBefore(a, [&] {
+ auto* val = b.Let(ty.u32());
+ val->SetValue(b.Zero<u32>());
+
+ b.MemberCall<hlsl::ir::MemberBuiltinCall>(
+ ty.void_(), BuiltinFn::kGetDimensions, var, val);
+
+ // Because the `runtime_array` must be the last element of the outer most
+ // structure and we're calling `arrayLength` on the array then the access
+ // chain must have only had a single item in it and that item must have been
+ // a constant offset.
+ auto* div =
+ b.Divide(ty.u32(), b.Subtract(ty.u32(), val, u32(offset->byte_offset)),
+ u32(arr_ty->Stride()));
+ call->Result(0)->ReplaceAllUsesWith(div->Result(0));
+ });
+ call->Destroy();
+ }, //
TINT_ICE_ON_NO_MATCH);
}
@@ -420,8 +477,8 @@
// TODO(dsinclair): Handle store
}
- // This should _only_ be handling a `var` parameter as any `access` parameters would have been
- // replaced by the `access` being converted.
+ // This should _only_ be handling a `var` parameter as any `access` parameters would have
+ // been replaced by the `access` being converted.
void Load(core::ir::Load* ld, core::ir::Var* var) {
auto* result = ld->From()->As<core::ir::InstructionResult>();
TINT_ASSERT(result);
@@ -445,10 +502,12 @@
void LoadVectorElement(core::ir::LoadVectorElement* lve,
core::ir::Var* var,
const core::type::Pointer* var_ty) {
- uint32_t pos = CalculateVectorIndex(lve->Index(), var_ty->StoreType());
-
b.InsertBefore(lve, [&] {
- auto* result = MakeScalarOrVectorLoad(var, lve->Result(0)->Type(), b.Value(u32(pos)));
+ OffsetData offset{};
+ UpdateOffsetData(lve->Index(), var_ty->StoreType()->DeepestElement()->Size(), &offset);
+
+ auto* result =
+ MakeScalarOrVectorLoad(var, lve->Result(0)->Type(), OffsetToValue(offset));
lve->Result(0)->ReplaceAllUsesWith(result->Result(0));
});
@@ -463,14 +522,14 @@
void StoreVectorElement(core::ir::StoreVectorElement* sve,
core::ir::Var* var,
const core::type::Pointer* var_ty) {
- uint32_t pos = CalculateVectorIndex(sve->Index(), var_ty->StoreType());
+ b.InsertBefore(sve, [&] {
+ OffsetData offset{};
+ UpdateOffsetData(sve->Index(), var_ty->StoreType()->DeepestElement()->Size(), &offset);
- auto* cast = b.Bitcast(ty.u32(), sve->Value());
- auto* builtin = b.MemberCall<hlsl::ir::MemberBuiltinCall>(ty.void_(), BuiltinFn::kStore,
- var, u32(pos), cast);
-
- cast->InsertBefore(sve);
- builtin->InsertBefore(sve);
+ auto* cast = b.Bitcast(ty.u32(), sve->Value());
+ b.MemberCall<hlsl::ir::MemberBuiltinCall>(ty.void_(), BuiltinFn::kStore, var,
+ OffsetToValue(offset), cast);
+ });
sve->Destroy();
}
};
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
index bc2be08..c6b1662 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
@@ -1074,7 +1074,7 @@
EXPECT_EQ(expect, str());
}
-TEST_F(HlslWriterDecomposeMemoryAccessTest, DISABLED_ComplexDynamicAccessChain) {
+TEST_F(HlslWriterDecomposeMemoryAccessTest, ComplexDynamicAccessChain) {
auto* S1 = ty.Struct(mod.symbols.New("S1"), {
{mod.symbols.New("a"), ty.i32()},
{mod.symbols.New("b"), ty.vec3<f32>()},
@@ -1176,8 +1176,125 @@
%6:u32 = load %j
%k:ptr<function, i32, read_write> = var, 2i
%8:i32 = load %k
- %10:f32 = bitcast %9
- %x:f32 = let %10
+ %9:u32 = convert %4
+ %10:u32 = mul %9, 128u
+ %11:u32 = convert %6
+ %12:u32 = mul %11, 32u
+ %13:u32 = convert %8
+ %14:u32 = mul %13, 4u
+ %15:u32 = add 48u, %10
+ %16:u32 = add %15, %12
+ %17:u32 = add %16, %14
+ %18:u32 = %sb.Load %17
+ %19:f32 = bitcast %18
+ %x:f32 = let %19
+ ret
+ }
+}
+)";
+
+ Run(DecomposeMemoryAccess);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterDecomposeMemoryAccessTest, ComplexDynamicAccessChainDynamicAccessInMiddle) {
+ auto* S1 = ty.Struct(mod.symbols.New("S1"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.vec3<f32>()},
+ {mod.symbols.New("c"), ty.i32()},
+ });
+ auto* S2 = ty.Struct(mod.symbols.New("S2"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.array(S1, 3)},
+ {mod.symbols.New("c"), ty.i32()},
+ });
+
+ auto* SB = ty.Struct(mod.symbols.New("SB"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.runtime_array(S2)},
+ });
+
+ auto* var = b.Var("sb", storage, SB, core::Access::kReadWrite);
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* j = b.Load(b.Var("j", 1_u));
+ // let x : f32 = sb.b[4].b[j].b[2];
+ b.Let("x", b.LoadVectorElement(b.Access(ty.ptr<storage, vec3<f32>, read_write>(), var, 1_u,
+ 4_u, 1_u, j, 1_u),
+ 2_u));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+S1 = struct @align(16) {
+ a:i32 @offset(0)
+ b:vec3<f32> @offset(16)
+ c:i32 @offset(28)
+}
+
+S2 = struct @align(16) {
+ a_1:i32 @offset(0)
+ b_1:array<S1, 3> @offset(16)
+ c_1:i32 @offset(112)
+}
+
+SB = struct @align(16) {
+ a_2:i32 @offset(0)
+ b_2:array<S2> @offset(16)
+}
+
+$B1: { # root
+ %sb:ptr<storage, SB, read_write> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %j:ptr<function, u32, read_write> = var, 1u
+ %4:u32 = load %j
+ %5:ptr<storage, vec3<f32>, read_write> = access %sb, 1u, 4u, 1u, %4, 1u
+ %6:f32 = load_vector_element %5, 2u
+ %x:f32 = let %6
+ ret
+ }
+}
+)";
+ ASSERT_EQ(src, str());
+
+ auto* expect = R"(
+S1 = struct @align(16) {
+ a:i32 @offset(0)
+ b:vec3<f32> @offset(16)
+ c:i32 @offset(28)
+}
+
+S2 = struct @align(16) {
+ a_1:i32 @offset(0)
+ b_1:array<S1, 3> @offset(16)
+ c_1:i32 @offset(112)
+}
+
+SB = struct @align(16) {
+ a_2:i32 @offset(0)
+ b_2:array<S2> @offset(16)
+}
+
+$B1: { # root
+ %sb:hlsl.byte_address_buffer<read_write> = var @binding_point(0, 0)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %j:ptr<function, u32, read_write> = var, 1u
+ %4:u32 = load %j
+ %5:u32 = convert %4
+ %6:u32 = mul %5, 32u
+ %7:u32 = add 568u, %6
+ %8:u32 = %sb.Load %7
+ %9:f32 = bitcast %8
+ %x:f32 = let %9
ret
}
}