blob: 743a992b444648692afe3fbcdcd2d8807804679a [file] [log] [blame]
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "src/tint/transform/calculate_array_length.h"
16
17#include <unordered_map>
18#include <utility>
19
20#include "src/tint/ast/call_statement.h"
21#include "src/tint/ast/disable_validation_attribute.h"
22#include "src/tint/program_builder.h"
23#include "src/tint/sem/block_statement.h"
24#include "src/tint/sem/call.h"
25#include "src/tint/sem/function.h"
26#include "src/tint/sem/statement.h"
27#include "src/tint/sem/struct.h"
28#include "src/tint/sem/variable.h"
Ben Clayton23946b32023-03-09 16:50:19 +000029#include "src/tint/switch.h"
Ryan Harrisondbc13af2022-02-21 15:19:07 +000030#include "src/tint/transform/simplify_pointers.h"
dan sinclair4d56b482022-12-08 17:50:50 +000031#include "src/tint/type/reference.h"
Ryan Harrisondbc13af2022-02-21 15:19:07 +000032#include "src/tint/utils/hash.h"
33#include "src/tint/utils/map.h"
34
35TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength);
dan sinclair41e4d9a2022-05-01 14:40:55 +000036TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
Ryan Harrisondbc13af2022-02-21 15:19:07 +000037
Ben Clayton0ce9ab02022-05-05 20:23:40 +000038using namespace tint::number_suffixes; // NOLINT
39
dan sinclairb5599d32022-04-07 16:55:14 +000040namespace tint::transform {
Ryan Harrisondbc13af2022-02-21 15:19:07 +000041
42namespace {
43
Ben Claytonc6b38142022-11-03 08:41:19 +000044bool ShouldRun(const Program* program) {
45 for (auto* fn : program->AST().Functions()) {
46 if (auto* sem_fn = program->Sem().Get(fn)) {
47 for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
dan sinclair9543f742023-03-09 01:20:16 +000048 if (builtin->Type() == builtin::Function::kArrayLength) {
Ben Claytonc6b38142022-11-03 08:41:19 +000049 return true;
50 }
51 }
52 }
53 }
54 return false;
55}
56
Ryan Harrisondbc13af2022-02-21 15:19:07 +000057/// ArrayUsage describes a runtime array usage.
58/// It is used as a key by the array_length_by_usage map.
59struct ArrayUsage {
dan sinclair41e4d9a2022-05-01 14:40:55 +000060 ast::BlockStatement const* const block;
61 sem::Variable const* const buffer;
62 bool operator==(const ArrayUsage& rhs) const {
63 return block == rhs.block && buffer == rhs.buffer;
Ryan Harrisondbc13af2022-02-21 15:19:07 +000064 }
dan sinclair41e4d9a2022-05-01 14:40:55 +000065 struct Hasher {
66 inline std::size_t operator()(const ArrayUsage& u) const {
67 return utils::Hash(u.block, u.buffer);
68 }
69 };
Ryan Harrisondbc13af2022-02-21 15:19:07 +000070};
71
72} // namespace
73
Ben Clayton4a92a3c2022-07-18 20:50:02 +000074CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid, ast::NodeID nid)
Ben Clayton63d0fab2023-03-06 15:43:16 +000075 : Base(pid, nid, utils::Empty) {}
Ryan Harrisondbc13af2022-02-21 15:19:07 +000076CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
77std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
dan sinclair41e4d9a2022-05-01 14:40:55 +000078 return "intrinsic_buffer_size";
Ryan Harrisondbc13af2022-02-21 15:19:07 +000079}
80
dan sinclair41e4d9a2022-05-01 14:40:55 +000081const CalculateArrayLength::BufferSizeIntrinsic* CalculateArrayLength::BufferSizeIntrinsic::Clone(
82 CloneContext* ctx) const {
Ben Clayton4a92a3c2022-07-18 20:50:02 +000083 return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>(
84 ctx->dst->ID(), ctx->dst->AllocateNodeID());
Ryan Harrisondbc13af2022-02-21 15:19:07 +000085}
86
87CalculateArrayLength::CalculateArrayLength() = default;
88CalculateArrayLength::~CalculateArrayLength() = default;
89
Ben Claytonc6b38142022-11-03 08:41:19 +000090Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
91 const DataMap&,
92 DataMap&) const {
93 if (!ShouldRun(src)) {
94 return SkipTransform;
Ryan Harrisondbc13af2022-02-21 15:19:07 +000095 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +000096
Ben Claytonc6b38142022-11-03 08:41:19 +000097 ProgramBuilder b;
98 CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
99 auto& sem = src->Sem();
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000100
dan sinclair41e4d9a2022-05-01 14:40:55 +0000101 // get_buffer_size_intrinsic() emits the function decorated with
102 // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
103 // [RW]ByteAddressBuffer.GetDimensions().
dan sinclair4d56b482022-12-08 17:50:50 +0000104 std::unordered_map<const type::Reference*, Symbol> buffer_size_intrinsics;
105 auto get_buffer_size_intrinsic = [&](const type::Reference* buffer_type) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000106 return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
Ben Claytonc6b38142022-11-03 08:41:19 +0000107 auto name = b.Sym();
Ben Clayton971318f2023-02-14 13:52:43 +0000108 auto type = CreateASTTypeFor(ctx, buffer_type);
Ben Claytonc6b38142022-11-03 08:41:19 +0000109 auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter);
Ben Clayton971318f2023-02-14 13:52:43 +0000110 b.Func(
111 name,
Ben Clayton783b1692022-08-02 17:03:35 +0000112 utils::Vector{
Ben Claytonc6b38142022-11-03 08:41:19 +0000113 b.Param("buffer",
114 b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()),
115 utils::Vector{disable_validation}),
dan sinclair2a651632023-02-19 04:03:55 +0000116 b.Param("result", b.ty.pointer(b.ty.u32(), builtin::AddressSpace::kFunction)),
dan sinclair41e4d9a2022-05-01 14:40:55 +0000117 },
Ben Claytonc6b38142022-11-03 08:41:19 +0000118 b.ty.void_(), nullptr,
Ben Clayton783b1692022-08-02 17:03:35 +0000119 utils::Vector{
Ben Claytonc6b38142022-11-03 08:41:19 +0000120 b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()),
Ben Clayton971318f2023-02-14 13:52:43 +0000121 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000122
dan sinclair41e4d9a2022-05-01 14:40:55 +0000123 return name;
124 });
125 };
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000126
dan sinclair41e4d9a2022-05-01 14:40:55 +0000127 std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage;
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000128
dan sinclair41e4d9a2022-05-01 14:40:55 +0000129 // Find all the arrayLength() calls...
Ben Claytonc6b38142022-11-03 08:41:19 +0000130 for (auto* node : src->ASTNodes().Objects()) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000131 if (auto* call_expr = node->As<ast::CallExpression>()) {
Ben Claytone5a67ac2022-05-19 21:50:59 +0000132 auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
dan sinclair41e4d9a2022-05-01 14:40:55 +0000133 if (auto* builtin = call->Target()->As<sem::Builtin>()) {
dan sinclair9543f742023-03-09 01:20:16 +0000134 if (builtin->Type() == builtin::Function::kArrayLength) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000135 // We're dealing with an arrayLength() call
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000136
Ben Clayton4b707762022-09-09 20:42:29 +0000137 if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
138 if (call_stmt->expr == call_expr) {
139 // arrayLength() is used as a statement.
140 // The argument expression must be side-effect free, so just drop the
141 // statement.
142 RemoveStatement(ctx, call_stmt);
143 continue;
144 }
145 }
146
Ben Clayton2032d032022-06-15 19:32:37 +0000147 // A runtime-sized array can only appear as the store type of a variable, or the
148 // last element of a structure (which cannot itself be nested). Given that we
149 // require SimplifyPointers, we can assume that the arrayLength() call has one
150 // of two forms:
dan sinclair41e4d9a2022-05-01 14:40:55 +0000151 // arrayLength(&struct_var.array_member)
152 // arrayLength(&array_var)
153 auto* arg = call_expr->args[0];
154 auto* address_of = arg->As<ast::UnaryOpExpression>();
Ben Clayton884f9522023-01-12 22:52:57 +0000155 if (TINT_UNLIKELY(!address_of || address_of->op != ast::UnaryOp::kAddressOf)) {
Ben Claytonc6b38142022-11-03 08:41:19 +0000156 TINT_ICE(Transform, b.Diagnostics())
dan sinclair41e4d9a2022-05-01 14:40:55 +0000157 << "arrayLength() expected address-of, got " << arg->TypeInfo().name;
158 }
159 auto* storage_buffer_expr = address_of->expr;
160 if (auto* accessor = storage_buffer_expr->As<ast::MemberAccessorExpression>()) {
Ben Claytonad315652023-02-05 12:36:50 +0000161 storage_buffer_expr = accessor->object;
dan sinclair41e4d9a2022-05-01 14:40:55 +0000162 }
163 auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
Ben Clayton884f9522023-01-12 22:52:57 +0000164 if (TINT_UNLIKELY(!storage_buffer_sem)) {
Ben Claytonc6b38142022-11-03 08:41:19 +0000165 TINT_ICE(Transform, b.Diagnostics())
dan sinclair41e4d9a2022-05-01 14:40:55 +0000166 << "expected form of arrayLength argument to be &array_var or "
167 "&struct_var.array_member";
168 break;
169 }
170 auto* storage_buffer_var = storage_buffer_sem->Variable();
dan sinclair4d56b482022-12-08 17:50:50 +0000171 auto* storage_buffer_type = storage_buffer_sem->Type()->As<type::Reference>();
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000172
Ben Clayton2032d032022-06-15 19:32:37 +0000173 // Generate BufferSizeIntrinsic for this storage type if we haven't already
dan sinclair41e4d9a2022-05-01 14:40:55 +0000174 auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000175
dan sinclair41e4d9a2022-05-01 14:40:55 +0000176 // Find the current statement block
177 auto* block = call->Stmt()->Block()->Declaration();
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000178
dan sinclair41e4d9a2022-05-01 14:40:55 +0000179 auto array_length =
180 utils::GetOrCreate(array_length_by_usage, {block, storage_buffer_var}, [&] {
181 // First time this array length is used for this block.
182 // Let's calculate it.
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000183
dan sinclair41e4d9a2022-05-01 14:40:55 +0000184 // Construct the variable that'll hold the result of
185 // RWByteAddressBuffer.GetDimensions()
Ben Claytonc6b38142022-11-03 08:41:19 +0000186 auto* buffer_size_result =
187 b.Decl(b.Var(b.Sym(), b.ty.u32(), b.Expr(0_u)));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000188
dan sinclair41e4d9a2022-05-01 14:40:55 +0000189 // Call storage_buffer.GetDimensions(&buffer_size_result)
Ben Claytonc6b38142022-11-03 08:41:19 +0000190 auto* call_get_dims = b.CallStmt(b.Call(
dan sinclair41e4d9a2022-05-01 14:40:55 +0000191 // BufferSizeIntrinsic(X, ARGS...) is
192 // translated to:
193 // X.GetDimensions(ARGS..) by the writer
Ben Claytonc6b38142022-11-03 08:41:19 +0000194 buffer_size, b.AddressOf(ctx.Clone(storage_buffer_expr)),
Ben Clayton651d9e22023-02-09 10:34:14 +0000195 b.AddressOf(b.Expr(buffer_size_result->variable->name->symbol))));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000196
dan sinclair41e4d9a2022-05-01 14:40:55 +0000197 // Calculate actual array length
198 // total_storage_buffer_size - array_offset
199 // array_length = ----------------------------------------
200 // array_stride
Ben Claytonc6b38142022-11-03 08:41:19 +0000201 auto name = b.Sym();
dan sinclair41e4d9a2022-05-01 14:40:55 +0000202 const ast::Expression* total_size =
Ben Claytonc6b38142022-11-03 08:41:19 +0000203 b.Expr(buffer_size_result->variable);
Ben Clayton2032d032022-06-15 19:32:37 +0000204
dan sinclair946858a2022-12-08 22:21:24 +0000205 const type::Array* array_type = Switch(
Ben Clayton2032d032022-06-15 19:32:37 +0000206 storage_buffer_type->StoreType(),
207 [&](const sem::Struct* str) {
208 // The variable is a struct, so subtract the byte offset of
209 // the array member.
dan sinclairad9cd0a2022-12-06 20:01:54 +0000210 auto* array_member_sem = str->Members().Back();
Ben Claytonc6b38142022-11-03 08:41:19 +0000211 total_size = b.Sub(total_size, u32(array_member_sem->Offset()));
dan sinclair946858a2022-12-08 22:21:24 +0000212 return array_member_sem->Type()->As<type::Array>();
Ben Clayton2032d032022-06-15 19:32:37 +0000213 },
dan sinclair946858a2022-12-08 22:21:24 +0000214 [&](const type::Array* arr) { return arr; });
Ben Clayton2032d032022-06-15 19:32:37 +0000215
Ben Clayton884f9522023-01-12 22:52:57 +0000216 if (TINT_UNLIKELY(!array_type)) {
Ben Claytonc6b38142022-11-03 08:41:19 +0000217 TINT_ICE(Transform, b.Diagnostics())
dan sinclair41e4d9a2022-05-01 14:40:55 +0000218 << "expected form of arrayLength argument to be "
219 "&array_var or &struct_var.array_member";
220 return name;
221 }
Ben Clayton2032d032022-06-15 19:32:37 +0000222
dan sinclair41e4d9a2022-05-01 14:40:55 +0000223 uint32_t array_stride = array_type->Size();
Ben Claytonc6b38142022-11-03 08:41:19 +0000224 auto* array_length_var = b.Decl(
225 b.Let(name, b.ty.u32(), b.Div(total_size, u32(array_stride))));
dan sinclair41e4d9a2022-05-01 14:40:55 +0000226
227 // Insert the array length calculations at the top of the block
228 ctx.InsertBefore(block->statements, block->statements[0],
229 buffer_size_result);
230 ctx.InsertBefore(block->statements, block->statements[0],
231 call_get_dims);
232 ctx.InsertBefore(block->statements, block->statements[0],
233 array_length_var);
234 return name;
235 });
236
237 // Replace the call to arrayLength() with the array length variable
Ben Claytonc6b38142022-11-03 08:41:19 +0000238 ctx.Replace(call_expr, b.Expr(array_length));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000239 }
dan sinclair41e4d9a2022-05-01 14:40:55 +0000240 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000241 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000242 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000243
dan sinclair41e4d9a2022-05-01 14:40:55 +0000244 ctx.Clone();
Ben Claytonc6b38142022-11-03 08:41:19 +0000245 return Program(std::move(b));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000246}
247
dan sinclairb5599d32022-04-07 16:55:14 +0000248} // namespace tint::transform