blob: 803ae73aaee0f7112672f244b1cded79fe9f42c4 [file] [log] [blame]
Ryan Harrisondbc13af2022-02-21 15:19:07 +00001// Copyright 2021 The Tint Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "src/tint/transform/decompose_strided_matrix.h"
16
17#include <memory>
18#include <utility>
19#include <vector>
20
21#include "src/tint/ast/disable_validation_attribute.h"
22#include "src/tint/program_builder.h"
23#include "src/tint/transform/simplify_pointers.h"
24#include "src/tint/transform/test_helper.h"
25#include "src/tint/transform/unshadow.h"
26
Ben Clayton0ce9ab02022-05-05 20:23:40 +000027using namespace tint::number_suffixes; // NOLINT
28
dan sinclairb5599d32022-04-07 16:55:14 +000029namespace tint::transform {
Ryan Harrisondbc13af2022-02-21 15:19:07 +000030namespace {
31
32using DecomposeStridedMatrixTest = TransformTest;
Ryan Harrisondbc13af2022-02-21 15:19:07 +000033
34TEST_F(DecomposeStridedMatrixTest, ShouldRunEmptyModule) {
dan sinclair41e4d9a2022-05-01 14:40:55 +000035 auto* src = R"()";
Ryan Harrisondbc13af2022-02-21 15:19:07 +000036
dan sinclair41e4d9a2022-05-01 14:40:55 +000037 EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
Ryan Harrisondbc13af2022-02-21 15:19:07 +000038}
39
40TEST_F(DecomposeStridedMatrixTest, ShouldRunNonStridedMatrox) {
dan sinclair41e4d9a2022-05-01 14:40:55 +000041 auto* src = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +000042var<private> m : mat3x2<f32>;
43)";
44
dan sinclair41e4d9a2022-05-01 14:40:55 +000045 EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
Ryan Harrisondbc13af2022-02-21 15:19:07 +000046}
47
48TEST_F(DecomposeStridedMatrixTest, Empty) {
dan sinclair41e4d9a2022-05-01 14:40:55 +000049 auto* src = R"()";
50 auto* expect = src;
Ryan Harrisondbc13af2022-02-21 15:19:07 +000051
dan sinclair41e4d9a2022-05-01 14:40:55 +000052 auto got = Run<DecomposeStridedMatrix>(src);
Ryan Harrisondbc13af2022-02-21 15:19:07 +000053
dan sinclair41e4d9a2022-05-01 14:40:55 +000054 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +000055}
56
57TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
dan sinclair41e4d9a2022-05-01 14:40:55 +000058 // struct S {
59 // @offset(16) @stride(32)
60 // @internal(ignore_stride_attribute)
61 // m : mat2x2<f32>,
62 // };
63 // @group(0) @binding(0) var<uniform> s : S;
64 //
65 // @stage(compute) @workgroup_size(1)
66 // fn f() {
67 // let x : mat2x2<f32> = s.m;
68 // }
69 ProgramBuilder b;
70 auto* S = b.Structure(
71 "S", {
72 b.Member("m", b.ty.mat2x2<f32>(),
73 {
74 b.create<ast::StructMemberOffsetAttribute>(16),
75 b.create<ast::StrideAttribute>(32),
76 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
77 }),
78 });
79 b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0));
80 b.Func("f", {}, b.ty.void_(),
81 {
82 b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
83 },
84 {
85 b.Stage(ast::PipelineStage::kCompute),
Ben Clayton0ce9ab02022-05-05 20:23:40 +000086 b.WorkgroupSize(1_i),
dan sinclair41e4d9a2022-05-01 14:40:55 +000087 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +000088
dan sinclair41e4d9a2022-05-01 14:40:55 +000089 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +000090struct S {
91 @size(16)
James Price3b671cb2022-03-28 14:31:22 +000092 padding : u32,
93 m : @stride(32) array<vec2<f32>, 2u>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +000094}
95
96@group(0) @binding(0) var<uniform> s : S;
97
98fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
99 return mat2x2<f32>(arr[0u], arr[1u]);
100}
101
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000102@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000103fn f() {
104 let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
105}
106)";
107
dan sinclair41e4d9a2022-05-01 14:40:55 +0000108 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000109
dan sinclair41e4d9a2022-05-01 14:40:55 +0000110 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000111}
112
113TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000114 // struct S {
115 // @offset(16) @stride(32)
116 // @internal(ignore_stride_attribute)
117 // m : mat2x2<f32>,
118 // };
119 // @group(0) @binding(0) var<uniform> s : S;
120 //
121 // @stage(compute) @workgroup_size(1)
122 // fn f() {
123 // let x : vec2<f32> = s.m[1];
124 // }
125 ProgramBuilder b;
126 auto* S = b.Structure(
127 "S", {
128 b.Member("m", b.ty.mat2x2<f32>(),
129 {
130 b.create<ast::StructMemberOffsetAttribute>(16),
131 b.create<ast::StrideAttribute>(32),
132 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
133 }),
134 });
135 b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0));
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000136 b.Func(
137 "f", {}, b.ty.void_(),
138 {
139 b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
140 },
141 {
142 b.Stage(ast::PipelineStage::kCompute),
143 b.WorkgroupSize(1_i),
144 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000145
dan sinclair41e4d9a2022-05-01 14:40:55 +0000146 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000147struct S {
148 @size(16)
James Price3b671cb2022-03-28 14:31:22 +0000149 padding : u32,
150 m : @stride(32) array<vec2<f32>, 2u>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000151}
152
153@group(0) @binding(0) var<uniform> s : S;
154
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000155@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000156fn f() {
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000157 let x : vec2<f32> = s.m[1i];
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000158}
159)";
160
dan sinclair41e4d9a2022-05-01 14:40:55 +0000161 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000162
dan sinclair41e4d9a2022-05-01 14:40:55 +0000163 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000164}
165
166TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000167 // struct S {
168 // @offset(16) @stride(8)
169 // @internal(ignore_stride_attribute)
170 // m : mat2x2<f32>,
171 // };
172 // @group(0) @binding(0) var<uniform> s : S;
173 //
174 // @stage(compute) @workgroup_size(1)
175 // fn f() {
176 // let x : mat2x2<f32> = s.m;
177 // }
178 ProgramBuilder b;
179 auto* S = b.Structure(
180 "S", {
181 b.Member("m", b.ty.mat2x2<f32>(),
182 {
183 b.create<ast::StructMemberOffsetAttribute>(16),
184 b.create<ast::StrideAttribute>(8),
185 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
186 }),
187 });
188 b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0));
189 b.Func("f", {}, b.ty.void_(),
190 {
191 b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
192 },
193 {
194 b.Stage(ast::PipelineStage::kCompute),
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000195 b.WorkgroupSize(1_i),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000196 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000197
dan sinclair41e4d9a2022-05-01 14:40:55 +0000198 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000199struct S {
200 @size(16)
James Price3b671cb2022-03-28 14:31:22 +0000201 padding : u32,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000202 @stride(8) @internal(disable_validation__ignore_stride)
James Price3b671cb2022-03-28 14:31:22 +0000203 m : mat2x2<f32>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000204}
205
206@group(0) @binding(0) var<uniform> s : S;
207
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000208@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000209fn f() {
210 let x : mat2x2<f32> = s.m;
211}
212)";
213
dan sinclair41e4d9a2022-05-01 14:40:55 +0000214 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000215
dan sinclair41e4d9a2022-05-01 14:40:55 +0000216 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000217}
218
219TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000220 // struct S {
221 // @offset(8) @stride(32)
222 // @internal(ignore_stride_attribute)
223 // m : mat2x2<f32>,
224 // };
225 // @group(0) @binding(0) var<storage, read_write> s : S;
226 //
227 // @stage(compute) @workgroup_size(1)
228 // fn f() {
229 // let x : mat2x2<f32> = s.m;
230 // }
231 ProgramBuilder b;
232 auto* S = b.Structure(
233 "S", {
234 b.Member("m", b.ty.mat2x2<f32>(),
235 {
236 b.create<ast::StructMemberOffsetAttribute>(8),
237 b.create<ast::StrideAttribute>(32),
238 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
239 }),
240 });
241 b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
242 b.GroupAndBinding(0, 0));
243 b.Func("f", {}, b.ty.void_(),
244 {
245 b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
246 },
247 {
248 b.Stage(ast::PipelineStage::kCompute),
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000249 b.WorkgroupSize(1_i),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000250 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000251
dan sinclair41e4d9a2022-05-01 14:40:55 +0000252 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000253struct S {
254 @size(8)
James Price3b671cb2022-03-28 14:31:22 +0000255 padding : u32,
256 m : @stride(32) array<vec2<f32>, 2u>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000257}
258
259@group(0) @binding(0) var<storage, read_write> s : S;
260
261fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
262 return mat2x2<f32>(arr[0u], arr[1u]);
263}
264
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000265@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000266fn f() {
267 let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
268}
269)";
270
dan sinclair41e4d9a2022-05-01 14:40:55 +0000271 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000272
dan sinclair41e4d9a2022-05-01 14:40:55 +0000273 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000274}
275
276TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000277 // struct S {
278 // @offset(16) @stride(32)
279 // @internal(ignore_stride_attribute)
280 // m : mat2x2<f32>,
281 // };
282 // @group(0) @binding(0) var<storage, read_write> s : S;
283 //
284 // @stage(compute) @workgroup_size(1)
285 // fn f() {
286 // let x : vec2<f32> = s.m[1];
287 // }
288 ProgramBuilder b;
289 auto* S = b.Structure(
290 "S", {
291 b.Member("m", b.ty.mat2x2<f32>(),
292 {
293 b.create<ast::StructMemberOffsetAttribute>(16),
294 b.create<ast::StrideAttribute>(32),
295 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
296 }),
297 });
298 b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
299 b.GroupAndBinding(0, 0));
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000300 b.Func(
301 "f", {}, b.ty.void_(),
302 {
303 b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
304 },
305 {
306 b.Stage(ast::PipelineStage::kCompute),
307 b.WorkgroupSize(1_i),
308 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000309
dan sinclair41e4d9a2022-05-01 14:40:55 +0000310 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000311struct S {
312 @size(16)
James Price3b671cb2022-03-28 14:31:22 +0000313 padding : u32,
314 m : @stride(32) array<vec2<f32>, 2u>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000315}
316
317@group(0) @binding(0) var<storage, read_write> s : S;
318
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000319@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000320fn f() {
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000321 let x : vec2<f32> = s.m[1i];
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000322}
323)";
324
dan sinclair41e4d9a2022-05-01 14:40:55 +0000325 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000326
dan sinclair41e4d9a2022-05-01 14:40:55 +0000327 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000328}
329
330TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000331 // struct S {
332 // @offset(8) @stride(32)
333 // @internal(ignore_stride_attribute)
334 // m : mat2x2<f32>,
335 // };
336 // @group(0) @binding(0) var<storage, read_write> s : S;
337 //
338 // @stage(compute) @workgroup_size(1)
339 // fn f() {
340 // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
341 // }
342 ProgramBuilder b;
343 auto* S = b.Structure(
344 "S", {
345 b.Member("m", b.ty.mat2x2<f32>(),
346 {
347 b.create<ast::StructMemberOffsetAttribute>(8),
348 b.create<ast::StrideAttribute>(32),
349 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
350 }),
351 });
352 b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
353 b.GroupAndBinding(0, 0));
354 b.Func("f", {}, b.ty.void_(),
355 {
356 b.Assign(b.MemberAccessor("s", "m"),
Ben Clayton0a3cda92022-05-10 17:30:15 +0000357 b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000358 },
359 {
360 b.Stage(ast::PipelineStage::kCompute),
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000361 b.WorkgroupSize(1_i),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000362 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000363
dan sinclair41e4d9a2022-05-01 14:40:55 +0000364 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000365struct S {
366 @size(8)
James Price3b671cb2022-03-28 14:31:22 +0000367 padding : u32,
368 m : @stride(32) array<vec2<f32>, 2u>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000369}
370
371@group(0) @binding(0) var<storage, read_write> s : S;
372
373fn mat2x2_stride_32_to_arr(m : mat2x2<f32>) -> @stride(32) array<vec2<f32>, 2u> {
374 return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
375}
376
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000377@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000378fn f() {
379 s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
380}
381)";
382
dan sinclair41e4d9a2022-05-01 14:40:55 +0000383 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000384
dan sinclair41e4d9a2022-05-01 14:40:55 +0000385 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000386}
387
388TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000389 // struct S {
390 // @offset(8) @stride(32)
391 // @internal(ignore_stride_attribute)
392 // m : mat2x2<f32>,
393 // };
394 // @group(0) @binding(0) var<storage, read_write> s : S;
395 //
396 // @stage(compute) @workgroup_size(1)
397 // fn f() {
398 // s.m[1] = vec2<f32>(1.0, 2.0);
399 // }
400 ProgramBuilder b;
401 auto* S = b.Structure(
402 "S", {
403 b.Member("m", b.ty.mat2x2<f32>(),
404 {
405 b.create<ast::StructMemberOffsetAttribute>(8),
406 b.create<ast::StrideAttribute>(32),
407 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
408 }),
409 });
410 b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
411 b.GroupAndBinding(0, 0));
412 b.Func("f", {}, b.ty.void_(),
413 {
Ben Clayton0a3cda92022-05-10 17:30:15 +0000414 b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i), b.vec2<f32>(1_f, 2_f)),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000415 },
416 {
417 b.Stage(ast::PipelineStage::kCompute),
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000418 b.WorkgroupSize(1_i),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000419 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000420
dan sinclair41e4d9a2022-05-01 14:40:55 +0000421 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000422struct S {
423 @size(8)
James Price3b671cb2022-03-28 14:31:22 +0000424 padding : u32,
425 m : @stride(32) array<vec2<f32>, 2u>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000426}
427
428@group(0) @binding(0) var<storage, read_write> s : S;
429
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000430@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000431fn f() {
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000432 s.m[1i] = vec2<f32>(1.0, 2.0);
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000433}
434)";
435
dan sinclair41e4d9a2022-05-01 14:40:55 +0000436 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000437
dan sinclair41e4d9a2022-05-01 14:40:55 +0000438 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000439}
440
441TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000442 // struct S {
443 // @offset(8) @stride(32)
444 // @internal(ignore_stride_attribute)
445 // m : mat2x2<f32>,
446 // };
447 // @group(0) @binding(0) var<storage, read_write> s : S;
448 //
449 // @stage(compute) @workgroup_size(1)
450 // fn f() {
451 // let a = &s.m;
452 // let b = &*&*(a);
453 // let x = *b;
454 // let y = (*b)[1];
455 // let z = x[1];
456 // (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
457 // (*b)[1] = vec2<f32>(5.0, 6.0);
458 // }
459 ProgramBuilder b;
460 auto* S = b.Structure(
461 "S", {
462 b.Member("m", b.ty.mat2x2<f32>(),
463 {
464 b.create<ast::StructMemberOffsetAttribute>(8),
465 b.create<ast::StrideAttribute>(32),
466 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
467 }),
468 });
469 b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite,
470 b.GroupAndBinding(0, 0));
Ben Clayton0a3cda92022-05-10 17:30:15 +0000471 b.Func("f", {}, b.ty.void_(),
472 {
473 b.Decl(b.Let("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))),
474 b.Decl(b.Let("b", nullptr, b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
475 b.Decl(b.Let("x", nullptr, b.Deref("b"))),
476 b.Decl(b.Let("y", nullptr, b.IndexAccessor(b.Deref("b"), 1_i))),
477 b.Decl(b.Let("z", nullptr, b.IndexAccessor("x", 1_i))),
478 b.Assign(b.Deref("b"), b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
479 b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), b.vec2<f32>(5_f, 6_f)),
480 },
481 {
482 b.Stage(ast::PipelineStage::kCompute),
483 b.WorkgroupSize(1_i),
484 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000485
dan sinclair41e4d9a2022-05-01 14:40:55 +0000486 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000487struct S {
488 @size(8)
James Price3b671cb2022-03-28 14:31:22 +0000489 padding : u32,
490 m : @stride(32) array<vec2<f32>, 2u>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000491}
492
493@group(0) @binding(0) var<storage, read_write> s : S;
494
495fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
496 return mat2x2<f32>(arr[0u], arr[1u]);
497}
498
499fn mat2x2_stride_32_to_arr(m : mat2x2<f32>) -> @stride(32) array<vec2<f32>, 2u> {
500 return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
501}
502
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000503@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000504fn f() {
505 let x = arr_to_mat2x2_stride_32(s.m);
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000506 let y = s.m[1i];
507 let z = x[1i];
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000508 s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000509 s.m[1i] = vec2<f32>(5.0, 6.0);
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000510}
511)";
512
dan sinclair41e4d9a2022-05-01 14:40:55 +0000513 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000514
dan sinclair41e4d9a2022-05-01 14:40:55 +0000515 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000516}
517
518TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000519 // struct S {
520 // @offset(8) @stride(32)
521 // @internal(ignore_stride_attribute)
522 // m : mat2x2<f32>,
523 // };
524 // var<private> s : S;
525 //
526 // @stage(compute) @workgroup_size(1)
527 // fn f() {
528 // let x : mat2x2<f32> = s.m;
529 // }
530 ProgramBuilder b;
531 auto* S = b.Structure(
532 "S", {
533 b.Member("m", b.ty.mat2x2<f32>(),
534 {
535 b.create<ast::StructMemberOffsetAttribute>(8),
536 b.create<ast::StrideAttribute>(32),
537 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
538 }),
539 });
540 b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
541 b.Func("f", {}, b.ty.void_(),
542 {
543 b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
544 },
545 {
546 b.Stage(ast::PipelineStage::kCompute),
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000547 b.WorkgroupSize(1_i),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000548 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000549
dan sinclair41e4d9a2022-05-01 14:40:55 +0000550 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000551struct S {
552 @size(8)
James Price3b671cb2022-03-28 14:31:22 +0000553 padding : u32,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000554 @stride(32) @internal(disable_validation__ignore_stride)
James Price3b671cb2022-03-28 14:31:22 +0000555 m : mat2x2<f32>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000556}
557
558var<private> s : S;
559
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000560@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000561fn f() {
562 let x : mat2x2<f32> = s.m;
563}
564)";
565
dan sinclair41e4d9a2022-05-01 14:40:55 +0000566 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000567
dan sinclair41e4d9a2022-05-01 14:40:55 +0000568 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000569}
570
571TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000572 // struct S {
573 // @offset(8) @stride(32)
574 // @internal(ignore_stride_attribute)
575 // m : mat2x2<f32>,
576 // };
577 // var<private> s : S;
578 //
579 // @stage(compute) @workgroup_size(1)
580 // fn f() {
581 // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
582 // }
583 ProgramBuilder b;
584 auto* S = b.Structure(
585 "S", {
586 b.Member("m", b.ty.mat2x2<f32>(),
587 {
588 b.create<ast::StructMemberOffsetAttribute>(8),
589 b.create<ast::StrideAttribute>(32),
590 b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
591 }),
592 });
593 b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
594 b.Func("f", {}, b.ty.void_(),
595 {
596 b.Assign(b.MemberAccessor("s", "m"),
Ben Clayton0a3cda92022-05-10 17:30:15 +0000597 b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000598 },
599 {
600 b.Stage(ast::PipelineStage::kCompute),
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000601 b.WorkgroupSize(1_i),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000602 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000603
dan sinclair41e4d9a2022-05-01 14:40:55 +0000604 auto* expect = R"(
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000605struct S {
606 @size(8)
James Price3b671cb2022-03-28 14:31:22 +0000607 padding : u32,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000608 @stride(32) @internal(disable_validation__ignore_stride)
James Price3b671cb2022-03-28 14:31:22 +0000609 m : mat2x2<f32>,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000610}
611
612var<private> s : S;
613
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000614@stage(compute) @workgroup_size(1i)
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000615fn f() {
616 s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
617}
618)";
619
dan sinclair41e4d9a2022-05-01 14:40:55 +0000620 auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(Program(std::move(b)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000621
dan sinclair41e4d9a2022-05-01 14:40:55 +0000622 EXPECT_EQ(expect, str(got));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000623}
624
625} // namespace
dan sinclairb5599d32022-04-07 16:55:14 +0000626} // namespace tint::transform