blob: f2bb2d7247080d81d9264464320b8956dcc5da6b [file] [log] [blame]
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

dan sinclair5f812622020-09-22 14:53:03 +000015#include "src/ast/stage_decoration.h"
dan sinclair196e0972020-11-13 18:13:24 +000016#include "src/writer/msl/test_helper.h"
dan sinclair2a599012020-06-23 17:48:40 +000017
18namespace tint {
19namespace writer {
20namespace msl {
21namespace {
22
dan sinclair196e0972020-11-13 18:13:24 +000023using MslGeneratorImplTest = TestHelper;
dan sinclair2a599012020-06-23 17:48:40 +000024
dan sinclair2dbe9aa2020-09-21 15:16:20 +000025TEST_F(MslGeneratorImplTest, Generate) {
Ben Clayton42d1e092021-02-02 14:29:15 +000026 Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
James Price95d40772021-03-11 17:39:32 +000027 ast::DecorationList{
Ben Clayton43073d82021-04-22 13:50:53 +000028 Stage(ast::PipelineStage::kCompute),
Sarahe6cb51e2021-06-29 18:39:44 +000029 WorkgroupSize(1),
Ben Clayton42d1e092021-02-02 14:29:15 +000030 });
dan sinclair2a599012020-06-23 17:48:40 +000031
Ben Claytonf12054e2021-01-21 16:15:00 +000032 GeneratorImpl& gen = Build();
33
dan sinclair196e0972020-11-13 18:13:24 +000034 ASSERT_TRUE(gen.Generate()) << gen.error();
35 EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
dan sinclair2a599012020-06-23 17:48:40 +000036
dan sinclair2aaf7b52021-02-25 21:17:47 +000037using namespace metal;
dan sinclair987376c2021-01-12 04:34:53 +000038kernel void my_func() {
Ben Clayton7c598092021-01-13 16:34:51 +000039 return;
dan sinclair2a599012020-06-23 17:48:40 +000040}
dan sinclair2dbe9aa2020-09-21 15:16:20 +000041
dan sinclair2a599012020-06-23 17:48:40 +000042)");
43}
44
dan sinclair7caf6e52020-07-15 20:51:16 +000045struct MslBuiltinData {
46 ast::Builtin builtin;
47 const char* attribute_name;
48};
49inline std::ostream& operator<<(std::ostream& out, MslBuiltinData data) {
50 out << data.builtin;
51 return out;
52}
dan sinclair196e0972020-11-13 18:13:24 +000053using MslBuiltinConversionTest = TestParamHelper<MslBuiltinData>;
dan sinclair7caf6e52020-07-15 20:51:16 +000054TEST_P(MslBuiltinConversionTest, Emit) {
55 auto params = GetParam();
56
Ben Claytonf12054e2021-01-21 16:15:00 +000057 GeneratorImpl& gen = Build();
58
dan sinclair196e0972020-11-13 18:13:24 +000059 EXPECT_EQ(gen.builtin_to_attribute(params.builtin),
dan sinclair7caf6e52020-07-15 20:51:16 +000060 std::string(params.attribute_name));
61}
62INSTANTIATE_TEST_SUITE_P(
63 MslGeneratorImplTest,
64 MslBuiltinConversionTest,
65 testing::Values(MslBuiltinData{ast::Builtin::kPosition, "position"},
dan sinclaird7335fa2021-01-18 15:51:13 +000066 MslBuiltinData{ast::Builtin::kVertexIndex, "vertex_id"},
67 MslBuiltinData{ast::Builtin::kInstanceIndex, "instance_id"},
dan sinclair7caf6e52020-07-15 20:51:16 +000068 MslBuiltinData{ast::Builtin::kFrontFacing, "front_facing"},
dan sinclair7caf6e52020-07-15 20:51:16 +000069 MslBuiltinData{ast::Builtin::kFragDepth, "depth(any)"},
dan sinclair7caf6e52020-07-15 20:51:16 +000070 MslBuiltinData{ast::Builtin::kLocalInvocationId,
71 "thread_position_in_threadgroup"},
dan sinclaird7335fa2021-01-18 15:51:13 +000072 MslBuiltinData{ast::Builtin::kLocalInvocationIndex,
dan sinclair7caf6e52020-07-15 20:51:16 +000073 "thread_index_in_threadgroup"},
74 MslBuiltinData{ast::Builtin::kGlobalInvocationId,
James Price2b5acac2021-02-09 19:13:25 +000075 "thread_position_in_grid"},
James Price395b4882021-04-16 19:57:34 +000076 MslBuiltinData{ast::Builtin::kWorkgroupId,
77 "threadgroup_position_in_grid"},
James Price922fce72021-09-13 17:11:58 +000078 MslBuiltinData{ast::Builtin::kNumWorkgroups,
79 "threadgroups_per_grid"},
James Pricee7dab3c2021-02-16 18:21:41 +000080 MslBuiltinData{ast::Builtin::kSampleIndex, "sample_id"},
James Price11e172a2021-08-05 16:21:59 +000081 MslBuiltinData{ast::Builtin::kSampleMask, "sample_mask"},
82 MslBuiltinData{ast::Builtin::kPointSize, "point_size"}));
dan sinclair7caf6e52020-07-15 20:51:16 +000083
James Price2c2aa2a2021-07-12 16:11:41 +000084TEST_F(MslGeneratorImplTest, HasInvariantAttribute_True) {
85 auto* out = Structure(
86 "Out", {Member("pos", ty.vec4<f32>(),
87 {Builtin(ast::Builtin::kPosition), Invariant()})});
88 Func("vert_main", ast::VariableList{}, ty.Of(out),
89 {Return(Construct(ty.Of(out)))}, {Stage(ast::PipelineStage::kVertex)});
90
91 GeneratorImpl& gen = Build();
92
93 ASSERT_TRUE(gen.Generate()) << gen.error();
94 EXPECT_TRUE(gen.HasInvariant());
95 EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
96
97using namespace metal;
98struct Out {
99 float4 pos [[position]] [[invariant]];
100};
101
102vertex Out vert_main() {
103 return {};
104}
105
106)");
107}
108
109TEST_F(MslGeneratorImplTest, HasInvariantAttribute_False) {
110 auto* out = Structure("Out", {Member("pos", ty.vec4<f32>(),
111 {Builtin(ast::Builtin::kPosition)})});
112 Func("vert_main", ast::VariableList{}, ty.Of(out),
113 {Return(Construct(ty.Of(out)))}, {Stage(ast::PipelineStage::kVertex)});
114
115 GeneratorImpl& gen = Build();
116
117 ASSERT_TRUE(gen.Generate()) << gen.error();
118 EXPECT_FALSE(gen.HasInvariant());
119 EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
120
121using namespace metal;
122struct Out {
123 float4 pos [[position]];
124};
125
126vertex Out vert_main() {
127 return {};
128}
129
130)");
131}
132
James Priceacaecab2021-09-13 19:56:01 +0000133TEST_F(MslGeneratorImplTest, WorkgroupMatrix) {
134 Global("m", ty.mat2x2<f32>(), ast::StorageClass::kWorkgroup);
135 Func("comp_main", ast::VariableList{}, ty.void_(),
136 {Decl(Const("x", nullptr, Expr("m")))},
137 {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
138
139 GeneratorImpl& gen = SanitizeAndBuild();
140
141 ASSERT_TRUE(gen.Generate()) << gen.error();
142 EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
143
144using namespace metal;
James Price1ca6fba2021-09-29 18:56:17 +0000145struct tint_symbol_3 {
146 float2x2 m;
147};
148
James Priceacaecab2021-09-13 19:56:01 +0000149void comp_main_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol) {
150 {
151 *(tint_symbol) = float2x2();
152 }
153 threadgroup_barrier(mem_flags::mem_threadgroup);
154 float2x2 const x = *(tint_symbol);
155}
156
James Price1ca6fba2021-09-29 18:56:17 +0000157kernel void comp_main(threadgroup tint_symbol_3* tint_symbol_2 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
158 comp_main_inner(local_invocation_index, &((*(tint_symbol_2)).m));
James Priceacaecab2021-09-13 19:56:01 +0000159 return;
160}
161
162)");
163
164 auto allocations = gen.DynamicWorkgroupAllocations();
165 ASSERT_TRUE(allocations.count("comp_main"));
166 ASSERT_EQ(allocations["comp_main"].size(), 1u);
167 EXPECT_EQ(allocations["comp_main"][0], 2u * 2u * sizeof(float));
168}
169
170TEST_F(MslGeneratorImplTest, WorkgroupMatrixInArray) {
171 Global("m", ty.array(ty.mat2x2<f32>(), 4), ast::StorageClass::kWorkgroup);
172 Func("comp_main", ast::VariableList{}, ty.void_(),
173 {Decl(Const("x", nullptr, Expr("m")))},
174 {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
175
176 GeneratorImpl& gen = SanitizeAndBuild();
177
178 ASSERT_TRUE(gen.Generate()) << gen.error();
179 EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
180
181using namespace metal;
182struct tint_array_wrapper {
183 float2x2 arr[4];
184};
James Price1ca6fba2021-09-29 18:56:17 +0000185struct tint_symbol_3 {
186 tint_array_wrapper m;
187};
James Priceacaecab2021-09-13 19:56:01 +0000188
189void comp_main_inner(uint local_invocation_index, threadgroup tint_array_wrapper* const tint_symbol) {
190 for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
191 uint const i = idx;
192 (*(tint_symbol)).arr[i] = float2x2();
193 }
194 threadgroup_barrier(mem_flags::mem_threadgroup);
195 tint_array_wrapper const x = *(tint_symbol);
196}
197
James Price1ca6fba2021-09-29 18:56:17 +0000198kernel void comp_main(threadgroup tint_symbol_3* tint_symbol_2 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
199 comp_main_inner(local_invocation_index, &((*(tint_symbol_2)).m));
James Priceacaecab2021-09-13 19:56:01 +0000200 return;
201}
202
203)");
204
205 auto allocations = gen.DynamicWorkgroupAllocations();
206 ASSERT_TRUE(allocations.count("comp_main"));
207 ASSERT_EQ(allocations["comp_main"].size(), 1u);
208 EXPECT_EQ(allocations["comp_main"][0], 4u * 2u * 2u * sizeof(float));
209}
210
211TEST_F(MslGeneratorImplTest, WorkgroupMatrixInStruct) {
212 Structure("S1", {
213 Member("m1", ty.mat2x2<f32>()),
214 Member("m2", ty.mat4x4<f32>()),
215 });
216 Structure("S2", {
217 Member("s", ty.type_name("S1")),
218 });
219 Global("s", ty.type_name("S2"), ast::StorageClass::kWorkgroup);
220 Func("comp_main", ast::VariableList{}, ty.void_(),
221 {Decl(Const("x", nullptr, Expr("s")))},
222 {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
223
224 GeneratorImpl& gen = SanitizeAndBuild();
225
226 ASSERT_TRUE(gen.Generate()) << gen.error();
227 EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
228
229using namespace metal;
230struct S1 {
231 float2x2 m1;
232 float4x4 m2;
233};
234struct S2 {
235 S1 s;
236};
James Price1ca6fba2021-09-29 18:56:17 +0000237struct tint_symbol_4 {
238 S2 s;
239};
James Priceacaecab2021-09-13 19:56:01 +0000240
241void comp_main_inner(uint local_invocation_index, threadgroup S2* const tint_symbol_1) {
242 {
243 S2 const tint_symbol = {};
244 *(tint_symbol_1) = tint_symbol;
245 }
246 threadgroup_barrier(mem_flags::mem_threadgroup);
247 S2 const x = *(tint_symbol_1);
248}
249
James Price1ca6fba2021-09-29 18:56:17 +0000250kernel void comp_main(threadgroup tint_symbol_4* tint_symbol_3 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
251 comp_main_inner(local_invocation_index, &((*(tint_symbol_3)).s));
James Priceacaecab2021-09-13 19:56:01 +0000252 return;
253}
254
255)");
256
257 auto allocations = gen.DynamicWorkgroupAllocations();
258 ASSERT_TRUE(allocations.count("comp_main"));
259 ASSERT_EQ(allocations["comp_main"].size(), 1u);
260 EXPECT_EQ(allocations["comp_main"][0],
261 (2 * 2 * sizeof(float)) + (4u * 4u * sizeof(float)));
262}
263
264TEST_F(MslGeneratorImplTest, WorkgroupMatrix_Multiples) {
265 Global("m1", ty.mat2x2<f32>(), ast::StorageClass::kWorkgroup);
266 Global("m2", ty.mat2x3<f32>(), ast::StorageClass::kWorkgroup);
267 Global("m3", ty.mat2x4<f32>(), ast::StorageClass::kWorkgroup);
268 Global("m4", ty.mat3x2<f32>(), ast::StorageClass::kWorkgroup);
269 Global("m5", ty.mat3x3<f32>(), ast::StorageClass::kWorkgroup);
270 Global("m6", ty.mat3x4<f32>(), ast::StorageClass::kWorkgroup);
271 Global("m7", ty.mat4x2<f32>(), ast::StorageClass::kWorkgroup);
272 Global("m8", ty.mat4x3<f32>(), ast::StorageClass::kWorkgroup);
273 Global("m9", ty.mat4x4<f32>(), ast::StorageClass::kWorkgroup);
274 Func("main1", ast::VariableList{}, ty.void_(),
275 {
276 Decl(Const("a1", nullptr, Expr("m1"))),
277 Decl(Const("a2", nullptr, Expr("m2"))),
278 Decl(Const("a3", nullptr, Expr("m3"))),
279 },
280 {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
281 Func("main2", ast::VariableList{}, ty.void_(),
282 {
283 Decl(Const("a1", nullptr, Expr("m4"))),
284 Decl(Const("a2", nullptr, Expr("m5"))),
285 Decl(Const("a3", nullptr, Expr("m6"))),
286 },
287 {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
288 Func("main3", ast::VariableList{}, ty.void_(),
289 {
290 Decl(Const("a1", nullptr, Expr("m7"))),
291 Decl(Const("a2", nullptr, Expr("m8"))),
292 Decl(Const("a3", nullptr, Expr("m9"))),
293 },
294 {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
295 Func("main4_no_usages", ast::VariableList{}, ty.void_(), {},
296 {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
297
298 GeneratorImpl& gen = SanitizeAndBuild();
299
300 ASSERT_TRUE(gen.Generate()) << gen.error();
301 EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
302
303using namespace metal;
James Price1ca6fba2021-09-29 18:56:17 +0000304struct tint_symbol_7 {
305 float2x2 m1;
306 float2x3 m2;
307 float2x4 m3;
308};
309struct tint_symbol_15 {
310 float3x2 m4;
311 float3x3 m5;
312 float3x4 m6;
313};
314struct tint_symbol_23 {
315 float4x2 m7;
316 float4x3 m8;
317 float4x4 m9;
318};
319
James Priceacaecab2021-09-13 19:56:01 +0000320void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol, threadgroup float2x3* const tint_symbol_1, threadgroup float2x4* const tint_symbol_2) {
321 {
322 *(tint_symbol) = float2x2();
323 *(tint_symbol_1) = float2x3();
324 *(tint_symbol_2) = float2x4();
325 }
326 threadgroup_barrier(mem_flags::mem_threadgroup);
327 float2x2 const a1 = *(tint_symbol);
328 float2x3 const a2 = *(tint_symbol_1);
329 float2x4 const a3 = *(tint_symbol_2);
330}
331
James Price1ca6fba2021-09-29 18:56:17 +0000332kernel void main1(threadgroup tint_symbol_7* tint_symbol_4 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
333 main1_inner(local_invocation_index, &((*(tint_symbol_4)).m1), &((*(tint_symbol_4)).m2), &((*(tint_symbol_4)).m3));
James Priceacaecab2021-09-13 19:56:01 +0000334 return;
335}
336
James Price1ca6fba2021-09-29 18:56:17 +0000337void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_8, threadgroup float3x3* const tint_symbol_9, threadgroup float3x4* const tint_symbol_10) {
James Priceacaecab2021-09-13 19:56:01 +0000338 {
James Price1ca6fba2021-09-29 18:56:17 +0000339 *(tint_symbol_8) = float3x2();
340 *(tint_symbol_9) = float3x3();
341 *(tint_symbol_10) = float3x4();
James Priceacaecab2021-09-13 19:56:01 +0000342 }
343 threadgroup_barrier(mem_flags::mem_threadgroup);
James Price1ca6fba2021-09-29 18:56:17 +0000344 float3x2 const a1 = *(tint_symbol_8);
345 float3x3 const a2 = *(tint_symbol_9);
346 float3x4 const a3 = *(tint_symbol_10);
James Priceacaecab2021-09-13 19:56:01 +0000347}
348
James Price1ca6fba2021-09-29 18:56:17 +0000349kernel void main2(threadgroup tint_symbol_15* tint_symbol_12 [[threadgroup(0)]], uint local_invocation_index_1 [[thread_index_in_threadgroup]]) {
350 main2_inner(local_invocation_index_1, &((*(tint_symbol_12)).m4), &((*(tint_symbol_12)).m5), &((*(tint_symbol_12)).m6));
James Priceacaecab2021-09-13 19:56:01 +0000351 return;
352}
353
James Price1ca6fba2021-09-29 18:56:17 +0000354void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_16, threadgroup float4x3* const tint_symbol_17, threadgroup float4x4* const tint_symbol_18) {
James Priceacaecab2021-09-13 19:56:01 +0000355 {
James Price1ca6fba2021-09-29 18:56:17 +0000356 *(tint_symbol_16) = float4x2();
357 *(tint_symbol_17) = float4x3();
358 *(tint_symbol_18) = float4x4();
James Priceacaecab2021-09-13 19:56:01 +0000359 }
360 threadgroup_barrier(mem_flags::mem_threadgroup);
James Price1ca6fba2021-09-29 18:56:17 +0000361 float4x2 const a1 = *(tint_symbol_16);
362 float4x3 const a2 = *(tint_symbol_17);
363 float4x4 const a3 = *(tint_symbol_18);
James Priceacaecab2021-09-13 19:56:01 +0000364}
365
James Price1ca6fba2021-09-29 18:56:17 +0000366kernel void main3(threadgroup tint_symbol_23* tint_symbol_20 [[threadgroup(0)]], uint local_invocation_index_2 [[thread_index_in_threadgroup]]) {
367 main3_inner(local_invocation_index_2, &((*(tint_symbol_20)).m7), &((*(tint_symbol_20)).m8), &((*(tint_symbol_20)).m9));
James Priceacaecab2021-09-13 19:56:01 +0000368 return;
369}
370
371kernel void main4_no_usages() {
372 return;
373}
374
375)");
376
377 auto allocations = gen.DynamicWorkgroupAllocations();
378 ASSERT_TRUE(allocations.count("main1"));
379 ASSERT_TRUE(allocations.count("main2"));
380 ASSERT_TRUE(allocations.count("main3"));
381 EXPECT_EQ(allocations.count("main4_no_usages"), 0u);
James Price1ca6fba2021-09-29 18:56:17 +0000382 ASSERT_EQ(allocations["main1"].size(), 1u);
383 EXPECT_EQ(allocations["main1"][0], 20u * sizeof(float));
384 ASSERT_EQ(allocations["main2"].size(), 1u);
385 EXPECT_EQ(allocations["main2"][0], 32u * sizeof(float));
386 ASSERT_EQ(allocations["main3"].size(), 1u);
387 EXPECT_EQ(allocations["main3"][0], 40u * sizeof(float));
James Priceacaecab2021-09-13 19:56:01 +0000388}
389
dan sinclair2a599012020-06-23 17:48:40 +0000390} // namespace
391} // namespace msl
392} // namespace writer
393} // namespace tint