| // Copyright 2021 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "src/transform/calculate_array_length.h" |
| |
| #include <unordered_map> |
| #include <utility> |
| |
| #include "src/ast/call_statement.h" |
| #include "src/ast/disable_validation_decoration.h" |
| #include "src/program_builder.h" |
| #include "src/sem/block_statement.h" |
| #include "src/sem/call.h" |
| #include "src/sem/statement.h" |
| #include "src/sem/struct.h" |
| #include "src/sem/variable.h" |
| #include "src/transform/simplify_pointers.h" |
| #include "src/utils/hash.h" |
| #include "src/utils/map.h" |
| |
| TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength); |
| TINT_INSTANTIATE_TYPEINFO( |
| tint::transform::CalculateArrayLength::BufferSizeIntrinsic); |
| |
| namespace tint { |
| namespace transform { |
| |
| namespace { |
| |
| /// ArrayUsage describes a runtime array usage. |
| /// It is used as a key by the array_length_by_usage map. |
| struct ArrayUsage { |
| ast::BlockStatement const* const block; |
| sem::Node const* const buffer; |
| bool operator==(const ArrayUsage& rhs) const { |
| return block == rhs.block && buffer == rhs.buffer; |
| } |
| struct Hasher { |
| inline std::size_t operator()(const ArrayUsage& u) const { |
| return utils::Hash(u.block, u.buffer); |
| } |
| }; |
| }; |
| |
| } // namespace |
| |
| CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid) |
| : Base(pid) {} |
| CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default; |
| std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const { |
| return "intrinsic_buffer_size"; |
| } |
| |
| const CalculateArrayLength::BufferSizeIntrinsic* |
| CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const { |
| return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>( |
| ctx->dst->ID()); |
| } |
| |
| CalculateArrayLength::CalculateArrayLength() = default; |
| CalculateArrayLength::~CalculateArrayLength() = default; |
| |
| void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) { |
| auto& sem = ctx.src->Sem(); |
| if (!Requires<SimplifyPointers>(ctx)) { |
| return; |
| } |
| |
| // get_buffer_size_intrinsic() emits the function decorated with |
| // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to |
| // [RW]ByteAddressBuffer.GetDimensions(). |
| std::unordered_map<const sem::Struct*, Symbol> buffer_size_intrinsics; |
| auto get_buffer_size_intrinsic = [&](const sem::Struct* buffer_type) { |
| return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { |
| auto name = ctx.dst->Sym(); |
| auto* buffer_typename = |
| ctx.dst->ty.type_name(ctx.Clone(buffer_type->Declaration()->name)); |
| auto* disable_validation = ctx.dst->Disable( |
| ast::DisabledValidation::kIgnoreConstructibleFunctionParameter); |
| auto* func = ctx.dst->create<ast::Function>( |
| name, |
| ast::VariableList{ |
| // Note: The buffer parameter requires the kStorage StorageClass |
| // in order for HLSL to emit this as a ByteAddressBuffer. |
| ctx.dst->create<ast::Variable>( |
| ctx.dst->Sym("buffer"), ast::StorageClass::kStorage, |
| ast::Access::kUndefined, buffer_typename, true, nullptr, |
| ast::DecorationList{disable_validation}), |
| ctx.dst->Param("result", |
| ctx.dst->ty.pointer(ctx.dst->ty.u32(), |
| ast::StorageClass::kFunction)), |
| }, |
| ctx.dst->ty.void_(), nullptr, |
| ast::DecorationList{ |
| ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()), |
| }, |
| ast::DecorationList{}); |
| ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), |
| buffer_type->Declaration(), func); |
| return name; |
| }); |
| }; |
| |
| std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> |
| array_length_by_usage; |
| |
| // Find all the arrayLength() calls... |
| for (auto* node : ctx.src->ASTNodes().Objects()) { |
| if (auto* call_expr = node->As<ast::CallExpression>()) { |
| auto* call = sem.Get(call_expr); |
| if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) { |
| if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) { |
| // We're dealing with an arrayLength() call |
| |
| // https://gpuweb.github.io/gpuweb/wgsl/#array-types states: |
| // |
| // * The last member of the structure type defining the store type for |
| // a variable in the storage storage class may be a runtime-sized |
| // array. |
| // * A runtime-sized array must not be used as the store type or |
| // contained within a store type in any other cases. |
| // * An expression must not evaluate to a runtime-sized array type. |
| // |
| // We can assume that the arrayLength() call has a single argument of |
| // the form: arrayLength(&X.Y) where X is an expression that resolves |
| // to the storage buffer structure, and Y is the runtime sized array. |
| auto* arg = call_expr->args[0]; |
| auto* address_of = arg->As<ast::UnaryOpExpression>(); |
| if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { |
| TINT_ICE(Transform, ctx.dst->Diagnostics()) |
| << "arrayLength() expected pointer to member access, got " |
| << address_of->TypeInfo().name; |
| } |
| auto* array_expr = address_of->expr; |
| |
| auto* accessor = array_expr->As<ast::MemberAccessorExpression>(); |
| if (!accessor) { |
| TINT_ICE(Transform, ctx.dst->Diagnostics()) |
| << "arrayLength() expected pointer to member access, got " |
| "pointer to " |
| << array_expr->TypeInfo().name; |
| break; |
| } |
| auto* storage_buffer_expr = accessor->structure; |
| auto* storage_buffer_sem = sem.Get(storage_buffer_expr); |
| auto* storage_buffer_type = |
| storage_buffer_sem->Type()->UnwrapRef()->As<sem::Struct>(); |
| |
| // Generate BufferSizeIntrinsic for this storage type if we haven't |
| // already |
| auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type); |
| |
| if (!storage_buffer_type) { |
| TINT_ICE(Transform, ctx.dst->Diagnostics()) |
| << "arrayLength(X.Y) expected X to be sem::Struct, got " |
| << storage_buffer_type->FriendlyName(ctx.src->Symbols()); |
| break; |
| } |
| |
| // Find the current statement block |
| auto* block = call->Stmt()->Block()->Declaration(); |
| |
| // If the storage_buffer_expr is resolves to a variable (typically |
| // true) then key the array_length from the variable. If not, key off |
| // the expression semantic node, which will be unique per call to |
| // arrayLength(). |
| const sem::Node* storage_buffer_usage = storage_buffer_sem; |
| if (auto* user = storage_buffer_sem->As<sem::VariableUser>()) { |
| storage_buffer_usage = user->Variable(); |
| } |
| |
| auto array_length = utils::GetOrCreate( |
| array_length_by_usage, {block, storage_buffer_usage}, [&] { |
| // First time this array length is used for this block. |
| // Let's calculate it. |
| |
| // Semantic info for the runtime array structure member |
| auto* array_member_sem = storage_buffer_type->Members().back(); |
| |
| // Construct the variable that'll hold the result of |
| // RWByteAddressBuffer.GetDimensions() |
| auto* buffer_size_result = ctx.dst->Decl( |
| ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(), |
| ast::StorageClass::kNone, ctx.dst->Expr(0u))); |
| |
| // Call storage_buffer.GetDimensions(&buffer_size_result) |
| auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call( |
| // BufferSizeIntrinsic(X, ARGS...) is |
| // translated to: |
| // X.GetDimensions(ARGS..) by the writer |
| buffer_size, ctx.Clone(storage_buffer_expr), |
| ctx.dst->AddressOf( |
| ctx.dst->Expr(buffer_size_result->variable->symbol)))); |
| |
| // Calculate actual array length |
| // total_storage_buffer_size - array_offset |
| // array_length = ---------------------------------------- |
| // array_stride |
| auto name = ctx.dst->Sym(); |
| uint32_t array_offset = array_member_sem->Offset(); |
| uint32_t array_stride = array_member_sem->Size(); |
| auto* array_length_var = ctx.dst->Decl(ctx.dst->Const( |
| name, ctx.dst->ty.u32(), |
| ctx.dst->Div( |
| ctx.dst->Sub(buffer_size_result->variable->symbol, |
| array_offset), |
| array_stride))); |
| |
| // Insert the array length calculations at the top of the block |
| ctx.InsertBefore(block->statements, block->statements[0], |
| buffer_size_result); |
| ctx.InsertBefore(block->statements, block->statements[0], |
| call_get_dims); |
| ctx.InsertBefore(block->statements, block->statements[0], |
| array_length_var); |
| return name; |
| }); |
| |
| // Replace the call to arrayLength() with the array length variable |
| ctx.Replace(call_expr, ctx.dst->Expr(array_length)); |
| } |
| } |
| } |
| } |
| |
| ctx.Clone(); |
| } |
| |
| } // namespace transform |
| } // namespace tint |