blob: f41fbb7127c69ccbcdec38a0ad1ed42dc174013b [file] [log] [blame]
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "src/tint/transform/spirv_atomic.h"
16
17#include <string>
18#include <unordered_map>
19#include <unordered_set>
20#include <utility>
21#include <vector>
22
23#include "src/tint/program_builder.h"
24#include "src/tint/sem/block_statement.h"
25#include "src/tint/sem/function.h"
26#include "src/tint/sem/index_accessor_expression.h"
27#include "src/tint/sem/member_accessor_expression.h"
Antonio Maiorano268d7b82022-06-24 22:28:23 +000028#include "src/tint/sem/statement.h"
Ben Clayton23946b32023-03-09 16:50:19 +000029#include "src/tint/switch.h"
dan sinclair4d56b482022-12-08 17:50:50 +000030#include "src/tint/type/reference.h"
Antonio Maiorano268d7b82022-06-24 22:28:23 +000031#include "src/tint/utils/map.h"
32#include "src/tint/utils/unique_vector.h"
33
34TINT_INSTANTIATE_TYPEINFO(tint::transform::SpirvAtomic);
35TINT_INSTANTIATE_TYPEINFO(tint::transform::SpirvAtomic::Stub);
36
37namespace tint::transform {
38
dan sinclair78f80672022-09-22 22:28:21 +000039using namespace tint::number_suffixes; // NOLINT
40
Ben Claytonc6b38142022-11-03 08:41:19 +000041/// PIMPL state for the transform
Antonio Maiorano268d7b82022-06-24 22:28:23 +000042struct SpirvAtomic::State {
43 private:
44 /// A struct that has been forked because a subset of members were made atomic.
45 struct ForkedStruct {
46 Symbol name;
47 std::unordered_set<size_t> atomic_members;
48 };
49
Ben Claytonc6b38142022-11-03 08:41:19 +000050 /// The source program
51 const Program* const src;
52 /// The target program builder
53 ProgramBuilder b;
54 /// The clone context
55 CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
Antonio Maiorano268d7b82022-06-24 22:28:23 +000056 std::unordered_map<const ast::Struct*, ForkedStruct> forked_structs;
57 std::unordered_set<const sem::Variable*> atomic_variables;
Ben Clayton3fb9a3f2023-02-04 21:20:26 +000058 utils::UniqueVector<const sem::ValueExpression*, 8> atomic_expressions;
Antonio Maiorano268d7b82022-06-24 22:28:23 +000059
60 public:
61 /// Constructor
Ben Claytonc6b38142022-11-03 08:41:19 +000062 /// @param program the source program
63 explicit State(const Program* program) : src(program) {}
Antonio Maiorano268d7b82022-06-24 22:28:23 +000064
65 /// Runs the transform
Ben Claytonc6b38142022-11-03 08:41:19 +000066 /// @returns the new program or SkipTransform if the transform is not required
67 ApplyResult Run() {
Ben Claytonc07de732022-12-06 19:41:22 +000068 bool made_changes = false;
69
Antonio Maiorano268d7b82022-06-24 22:28:23 +000070 // Look for stub functions generated by the SPIR-V reader, which are used as placeholders
71 // for atomic builtin calls.
72 for (auto* fn : ctx.src->AST().Functions()) {
73 if (auto* stub = ast::GetAttribute<Stub>(fn->attributes)) {
74 auto* sem = ctx.src->Sem().Get(fn);
75
76 for (auto* call : sem->CallSites()) {
77 // The first argument is always the atomic.
78 // The stub passes this by value, whereas the builtin wants a pointer.
79 // Take the address of the atomic argument.
80 auto& args = call->Declaration()->args;
81 auto out_args = ctx.Clone(args);
82 out_args[0] = b.AddressOf(out_args[0]);
83
84 // Replace all callsites of this stub to a call to the real builtin
dan sinclair9543f742023-03-09 01:20:16 +000085 if (stub->builtin == builtin::Function::kAtomicCompareExchangeWeak) {
Antonio Maiorano268d7b82022-06-24 22:28:23 +000086 // atomicCompareExchangeWeak returns a struct, so insert a call to it above
87 // the current statement, and replace the current call with the struct's
88 // `old_value` member.
89 auto* block = call->Stmt()->Block()->Declaration();
90 auto old_value = b.Symbols().New("old_value");
91 auto old_value_decl = b.Decl(b.Let(
dan sinclair9543f742023-03-09 01:20:16 +000092 old_value, b.MemberAccessor(
93 b.Call(builtin::str(stub->builtin), std::move(out_args)),
94 "old_value")));
Antonio Maiorano268d7b82022-06-24 22:28:23 +000095 ctx.InsertBefore(block->statements, call->Stmt()->Declaration(),
96 old_value_decl);
97 ctx.Replace(call->Declaration(), b.Expr(old_value));
98 } else {
99 ctx.Replace(call->Declaration(),
dan sinclair9543f742023-03-09 01:20:16 +0000100 b.Call(builtin::str(stub->builtin), std::move(out_args)));
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000101 }
102
James Pricea7cd3ae2022-11-09 12:16:56 +0000103 // Keep track of this expression. We'll need to modify the root identifier /
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000104 // structure to be atomic.
Ben Clayton0b4a2f12023-02-05 22:59:40 +0000105 atomic_expressions.Add(ctx.src->Sem().GetVal(args[0]));
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000106 }
107
108 // Remove the stub from the output program
109 ctx.Remove(ctx.src->AST().GlobalDeclarations(), fn);
Ben Claytonc07de732022-12-06 19:41:22 +0000110 made_changes = true;
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000111 }
112 }
113
Ben Claytonc07de732022-12-06 19:41:22 +0000114 if (!made_changes) {
Ben Claytonc6b38142022-11-03 08:41:19 +0000115 return SkipTransform;
116 }
117
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000118 // Transform all variables and structure members that were used in atomic operations as
119 // atomic types. This propagates up originating expression chains.
120 ProcessAtomicExpressions();
121
122 // If we need to change structure members, then fork them.
123 if (!forked_structs.empty()) {
124 ctx.ReplaceAll([&](const ast::Struct* str) {
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000125 // Is `str` a structure we need to fork?
126 if (auto it = forked_structs.find(str); it != forked_structs.end()) {
127 const auto& forked = it->second;
128
129 // Re-create the structure swapping in the atomic-flavoured members
Ben Clayton783b1692022-08-02 17:03:35 +0000130 utils::Vector<const ast::StructMember*, 8> members;
131 members.Reserve(str->members.Length());
132 for (size_t i = 0; i < str->members.Length(); i++) {
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000133 auto* member = str->members[i];
134 if (forked.atomic_members.count(i)) {
Ben Clayton971318f2023-02-14 13:52:43 +0000135 auto type = AtomicTypeFor(ctx.src->Sem().Get(member)->Type());
Ben Clayton199440e2023-02-09 10:34:14 +0000136 auto name = ctx.src->Symbols().NameFor(member->name->symbol);
Ben Clayton783b1692022-08-02 17:03:35 +0000137 members.Push(b.Member(name, type, ctx.Clone(member->attributes)));
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000138 } else {
Ben Clayton783b1692022-08-02 17:03:35 +0000139 members.Push(ctx.Clone(member));
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000140 }
141 }
142 b.Structure(forked.name, std::move(members));
143 }
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000144 return nullptr;
145 });
146 }
147
Antonio Maioranoc0d51f12022-06-29 22:21:31 +0000148 // Replace assignments and decls from atomic variables with atomicLoads, and assignments to
149 // atomic variables with atomicStores.
150 ReplaceLoadsAndStores();
151
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000152 ctx.Clone();
Ben Claytonc6b38142022-11-03 08:41:19 +0000153 return Program(std::move(b));
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000154 }
155
156 private:
157 ForkedStruct& Fork(const ast::Struct* str) {
158 auto& forked = forked_structs[str];
159 if (!forked.name.IsValid()) {
Ben Claytonb75252b2023-02-09 10:34:14 +0000160 forked.name =
161 b.Symbols().New(ctx.src->Symbols().NameFor(str->name->symbol) + "_atomic");
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000162 }
163 return forked;
164 }
165
166 void ProcessAtomicExpressions() {
Ben Claytondce63f52022-08-17 18:07:20 +0000167 for (size_t i = 0; i < atomic_expressions.Length(); i++) {
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000168 Switch(
Ben Clayton2f9a9882022-12-17 02:20:04 +0000169 atomic_expressions[i]->UnwrapLoad(), //
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000170 [&](const sem::VariableUser* user) {
171 auto* v = user->Variable()->Declaration();
172 if (v->type && atomic_variables.emplace(user->Variable()).second) {
Ben Clayton971318f2023-02-14 13:52:43 +0000173 ctx.Replace(v->type.expr, b.Expr(AtomicTypeFor(user->Variable()->Type())));
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000174 }
dan sinclair6e77b472022-10-20 13:38:28 +0000175 if (auto* ctor = user->Variable()->Initializer()) {
Ben Claytondce63f52022-08-17 18:07:20 +0000176 atomic_expressions.Add(ctor);
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000177 }
178 },
179 [&](const sem::StructMemberAccess* access) {
180 // Fork the struct (the first time) and mark member(s) that need to be made
181 // atomic.
182 auto* member = access->Member();
183 Fork(member->Struct()->Declaration()).atomic_members.emplace(member->Index());
Ben Claytondce63f52022-08-17 18:07:20 +0000184 atomic_expressions.Add(access->Object());
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000185 },
186 [&](const sem::IndexAccessorExpression* index) {
Ben Claytondce63f52022-08-17 18:07:20 +0000187 atomic_expressions.Add(index->Object());
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000188 },
Ben Clayton3fb9a3f2023-02-04 21:20:26 +0000189 [&](const sem::ValueExpression* e) {
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000190 if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) {
Ben Clayton0b4a2f12023-02-05 22:59:40 +0000191 atomic_expressions.Add(ctx.src->Sem().GetVal(unary->expr));
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000192 }
193 });
194 }
195 }
196
Ben Clayton971318f2023-02-14 13:52:43 +0000197 ast::Type AtomicTypeFor(const type::Type* ty) {
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000198 return Switch(
199 ty, //
dan sinclaird37ecf92022-12-08 16:39:59 +0000200 [&](const type::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
201 [&](const type::U32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
Ben Clayton2117f802023-02-03 14:01:43 +0000202 [&](const sem::Struct* str) { return b.ty(Fork(str->Declaration()).name); },
Ben Clayton971318f2023-02-14 13:52:43 +0000203 [&](const type::Array* arr) {
dan sinclair5f764d82022-12-08 00:32:27 +0000204 if (arr->Count()->Is<type::RuntimeArrayCount>()) {
dan sinclair78f80672022-09-22 22:28:21 +0000205 return b.ty.array(AtomicTypeFor(arr->ElemType()));
206 }
207 auto count = arr->ConstantCount();
208 if (!count) {
209 ctx.dst->Diagnostics().add_error(
210 diag::System::Transform,
211 "the SpirvAtomic transform does not currently support array counts that "
212 "use override values");
213 count = 1;
214 }
215 return b.ty.array(AtomicTypeFor(arr->ElemType()), u32(count.value()));
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000216 },
dan sinclair4d56b482022-12-08 17:50:50 +0000217 [&](const type::Pointer* ptr) {
dan sinclairff7cf212022-10-03 14:05:23 +0000218 return b.ty.pointer(AtomicTypeFor(ptr->StoreType()), ptr->AddressSpace(),
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000219 ptr->Access());
220 },
dan sinclair4d56b482022-12-08 17:50:50 +0000221 [&](const type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); },
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000222 [&](Default) {
223 TINT_ICE(Transform, b.Diagnostics())
224 << "unhandled type: " << ty->FriendlyName(ctx.src->Symbols());
Ben Clayton971318f2023-02-14 13:52:43 +0000225 return ast::Type{};
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000226 });
227 }
Antonio Maioranoc0d51f12022-06-29 22:21:31 +0000228
229 void ReplaceLoadsAndStores() {
230 // Returns true if 'e' is a reference to an atomic variable or struct member
Ben Clayton3fb9a3f2023-02-04 21:20:26 +0000231 auto is_ref_to_atomic_var = [&](const sem::ValueExpression* e) {
dan sinclair4d56b482022-12-08 17:50:50 +0000232 if (tint::Is<type::Reference>(e->Type()) && e->RootIdentifier() &&
James Pricea7cd3ae2022-11-09 12:16:56 +0000233 (atomic_variables.count(e->RootIdentifier()) != 0)) {
Antonio Maioranoc0d51f12022-06-29 22:21:31 +0000234 // If it's a struct member, make sure it's one we marked as atomic
235 if (auto* ma = e->As<sem::StructMemberAccess>()) {
236 auto it = forked_structs.find(ma->Member()->Struct()->Declaration());
237 if (it != forked_structs.end()) {
238 auto& forked = it->second;
239 return forked.atomic_members.count(ma->Member()->Index()) != 0;
240 }
241 }
242 return true;
243 }
244 return false;
245 };
246
247 // Look for loads and stores via assignments and decls of atomic variables we've collected
248 // so far, and replace them with atomicLoad and atomicStore.
249 for (auto* atomic_var : atomic_variables) {
250 for (auto* vu : atomic_var->Users()) {
251 Switch(
252 vu->Stmt()->Declaration(),
253 [&](const ast::AssignmentStatement* assign) {
Ben Clayton0b4a2f12023-02-05 22:59:40 +0000254 auto* sem_lhs = ctx.src->Sem().GetVal(assign->lhs);
Antonio Maioranoc0d51f12022-06-29 22:21:31 +0000255 if (is_ref_to_atomic_var(sem_lhs)) {
256 ctx.Replace(assign, [=] {
257 auto* lhs = ctx.CloneWithoutTransform(assign->lhs);
258 auto* rhs = ctx.CloneWithoutTransform(assign->rhs);
dan sinclair9543f742023-03-09 01:20:16 +0000259 auto* call = b.Call(builtin::str(builtin::Function::kAtomicStore),
Antonio Maioranoc0d51f12022-06-29 22:21:31 +0000260 b.AddressOf(lhs), rhs);
261 return b.CallStmt(call);
262 });
263 return;
264 }
265
Ben Clayton0b4a2f12023-02-05 22:59:40 +0000266 auto sem_rhs = ctx.src->Sem().GetVal(assign->rhs);
Ben Clayton2f9a9882022-12-17 02:20:04 +0000267 if (is_ref_to_atomic_var(sem_rhs->UnwrapLoad())) {
Antonio Maioranoc0d51f12022-06-29 22:21:31 +0000268 ctx.Replace(assign->rhs, [=] {
269 auto* rhs = ctx.CloneWithoutTransform(assign->rhs);
dan sinclair9543f742023-03-09 01:20:16 +0000270 return b.Call(builtin::str(builtin::Function::kAtomicLoad),
Antonio Maioranoc0d51f12022-06-29 22:21:31 +0000271 b.AddressOf(rhs));
272 });
273 return;
274 }
275 },
276 [&](const ast::VariableDeclStatement* decl) {
277 auto* var = decl->variable;
Ben Clayton0b4a2f12023-02-05 22:59:40 +0000278 if (auto* sem_init = ctx.src->Sem().GetVal(var->initializer)) {
Ben Clayton2f9a9882022-12-17 02:20:04 +0000279 if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) {
dan sinclair6e77b472022-10-20 13:38:28 +0000280 ctx.Replace(var->initializer, [=] {
281 auto* rhs = ctx.CloneWithoutTransform(var->initializer);
dan sinclair9543f742023-03-09 01:20:16 +0000282 return b.Call(builtin::str(builtin::Function::kAtomicLoad),
Antonio Maioranoc0d51f12022-06-29 22:21:31 +0000283 b.AddressOf(rhs));
284 });
285 return;
286 }
287 }
288 });
289 }
290 }
291 }
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000292};
293
294SpirvAtomic::SpirvAtomic() = default;
295SpirvAtomic::~SpirvAtomic() = default;
296
dan sinclair9543f742023-03-09 01:20:16 +0000297SpirvAtomic::Stub::Stub(ProgramID pid, ast::NodeID nid, builtin::Function b)
Ben Clayton63d0fab2023-03-06 15:43:16 +0000298 : Base(pid, nid, utils::Empty), builtin(b) {}
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000299SpirvAtomic::Stub::~Stub() = default;
300std::string SpirvAtomic::Stub::InternalName() const {
dan sinclair9543f742023-03-09 01:20:16 +0000301 return "@internal(spirv-atomic " + std::string(builtin::str(builtin)) + ")";
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000302}
303
304const SpirvAtomic::Stub* SpirvAtomic::Stub::Clone(CloneContext* ctx) const {
Ben Clayton4a92a3c2022-07-18 20:50:02 +0000305 return ctx->dst->ASTNodes().Create<SpirvAtomic::Stub>(ctx->dst->ID(),
306 ctx->dst->AllocateNodeID(), builtin);
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000307}
308
Ben Claytonc6b38142022-11-03 08:41:19 +0000309Transform::ApplyResult SpirvAtomic::Apply(const Program* src, const DataMap&, DataMap&) const {
310 return State{src}.Run();
Antonio Maiorano268d7b82022-06-24 22:28:23 +0000311}
312
313} // namespace tint::transform