|  | // Copyright 2021 The Tint Authors. | 
|  | // | 
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | // you may not use this file except in compliance with the License. | 
|  | // You may obtain a copy of the License at | 
|  | // | 
|  | //     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, software | 
|  | // distributed under the License is distributed on an "AS IS" BASIS, | 
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | // See the License for the specific language governing permissions and | 
|  | // limitations under the License. | 
|  |  | 
|  | #include "src/transform/calculate_array_length.h" | 
|  |  | 
|  | #include <unordered_map> | 
|  | #include <utility> | 
|  |  | 
|  | #include "src/ast/call_statement.h" | 
|  | #include "src/ast/disable_validation_decoration.h" | 
|  | #include "src/program_builder.h" | 
|  | #include "src/sem/block_statement.h" | 
|  | #include "src/sem/call.h" | 
|  | #include "src/sem/statement.h" | 
|  | #include "src/sem/struct.h" | 
|  | #include "src/sem/variable.h" | 
|  | #include "src/transform/simplify_pointers.h" | 
|  | #include "src/utils/hash.h" | 
|  | #include "src/utils/map.h" | 
|  |  | 
|  | TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength); | 
|  | TINT_INSTANTIATE_TYPEINFO( | 
|  | tint::transform::CalculateArrayLength::BufferSizeIntrinsic); | 
|  |  | 
|  | namespace tint { | 
|  | namespace transform { | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | /// ArrayUsage describes a runtime array usage. | 
|  | /// It is used as a key by the array_length_by_usage map. | 
|  | struct ArrayUsage { | 
|  | ast::BlockStatement const* const block; | 
|  | sem::Node const* const buffer; | 
|  | bool operator==(const ArrayUsage& rhs) const { | 
|  | return block == rhs.block && buffer == rhs.buffer; | 
|  | } | 
|  | struct Hasher { | 
|  | inline std::size_t operator()(const ArrayUsage& u) const { | 
|  | return utils::Hash(u.block, u.buffer); | 
|  | } | 
|  | }; | 
|  | }; | 
|  |  | 
|  | }  // namespace | 
|  |  | 
|  | CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid) | 
|  | : Base(pid) {} | 
|  | CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default; | 
|  | std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const { | 
|  | return "intrinsic_buffer_size"; | 
|  | } | 
|  |  | 
|  | const CalculateArrayLength::BufferSizeIntrinsic* | 
|  | CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const { | 
|  | return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>( | 
|  | ctx->dst->ID()); | 
|  | } | 
|  |  | 
|  | CalculateArrayLength::CalculateArrayLength() = default; | 
|  | CalculateArrayLength::~CalculateArrayLength() = default; | 
|  |  | 
|  | void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) { | 
|  | auto& sem = ctx.src->Sem(); | 
|  | if (!Requires<SimplifyPointers>(ctx)) { | 
|  | return; | 
|  | } | 
|  |  | 
|  | // get_buffer_size_intrinsic() emits the function decorated with | 
|  | // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to | 
|  | // [RW]ByteAddressBuffer.GetDimensions(). | 
|  | std::unordered_map<const sem::Struct*, Symbol> buffer_size_intrinsics; | 
|  | auto get_buffer_size_intrinsic = [&](const sem::Struct* buffer_type) { | 
|  | return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { | 
|  | auto name = ctx.dst->Sym(); | 
|  | auto* buffer_typename = | 
|  | ctx.dst->ty.type_name(ctx.Clone(buffer_type->Declaration()->name)); | 
|  | auto* disable_validation = ctx.dst->Disable( | 
|  | ast::DisabledValidation::kIgnoreConstructibleFunctionParameter); | 
|  | auto* func = ctx.dst->create<ast::Function>( | 
|  | name, | 
|  | ast::VariableList{ | 
|  | // Note: The buffer parameter requires the kStorage StorageClass | 
|  | // in order for HLSL to emit this as a ByteAddressBuffer. | 
|  | ctx.dst->create<ast::Variable>( | 
|  | ctx.dst->Sym("buffer"), ast::StorageClass::kStorage, | 
|  | ast::Access::kUndefined, buffer_typename, true, nullptr, | 
|  | ast::DecorationList{disable_validation}), | 
|  | ctx.dst->Param("result", | 
|  | ctx.dst->ty.pointer(ctx.dst->ty.u32(), | 
|  | ast::StorageClass::kFunction)), | 
|  | }, | 
|  | ctx.dst->ty.void_(), nullptr, | 
|  | ast::DecorationList{ | 
|  | ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()), | 
|  | }, | 
|  | ast::DecorationList{}); | 
|  | ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), | 
|  | buffer_type->Declaration(), func); | 
|  | return name; | 
|  | }); | 
|  | }; | 
|  |  | 
|  | std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> | 
|  | array_length_by_usage; | 
|  |  | 
|  | // Find all the arrayLength() calls... | 
|  | for (auto* node : ctx.src->ASTNodes().Objects()) { | 
|  | if (auto* call_expr = node->As<ast::CallExpression>()) { | 
|  | auto* call = sem.Get(call_expr); | 
|  | if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) { | 
|  | if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) { | 
|  | // We're dealing with an arrayLength() call | 
|  |  | 
|  | // https://gpuweb.github.io/gpuweb/wgsl/#array-types states: | 
|  | // | 
|  | // * The last member of the structure type defining the store type for | 
|  | //   a variable in the storage storage class may be a runtime-sized | 
|  | //   array. | 
|  | // * A runtime-sized array must not be used as the store type or | 
|  | //   contained within a store type in any other cases. | 
|  | // * An expression must not evaluate to a runtime-sized array type. | 
|  | // | 
|  | // We can assume that the arrayLength() call has a single argument of | 
|  | // the form: arrayLength(&X.Y) where X is an expression that resolves | 
|  | // to the storage buffer structure, and Y is the runtime sized array. | 
|  | auto* arg = call_expr->args[0]; | 
|  | auto* address_of = arg->As<ast::UnaryOpExpression>(); | 
|  | if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { | 
|  | TINT_ICE(Transform, ctx.dst->Diagnostics()) | 
|  | << "arrayLength() expected pointer to member access, got " | 
|  | << address_of->TypeInfo().name; | 
|  | } | 
|  | auto* array_expr = address_of->expr; | 
|  |  | 
|  | auto* accessor = array_expr->As<ast::MemberAccessorExpression>(); | 
|  | if (!accessor) { | 
|  | TINT_ICE(Transform, ctx.dst->Diagnostics()) | 
|  | << "arrayLength() expected pointer to member access, got " | 
|  | "pointer to " | 
|  | << array_expr->TypeInfo().name; | 
|  | break; | 
|  | } | 
|  | auto* storage_buffer_expr = accessor->structure; | 
|  | auto* storage_buffer_sem = sem.Get(storage_buffer_expr); | 
|  | auto* storage_buffer_type = | 
|  | storage_buffer_sem->Type()->UnwrapRef()->As<sem::Struct>(); | 
|  |  | 
|  | // Generate BufferSizeIntrinsic for this storage type if we haven't | 
|  | // already | 
|  | auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type); | 
|  |  | 
|  | if (!storage_buffer_type) { | 
|  | TINT_ICE(Transform, ctx.dst->Diagnostics()) | 
|  | << "arrayLength(X.Y) expected X to be sem::Struct, got " | 
|  | << storage_buffer_type->FriendlyName(ctx.src->Symbols()); | 
|  | break; | 
|  | } | 
|  |  | 
|  | // Find the current statement block | 
|  | auto* block = call->Stmt()->Block()->Declaration(); | 
|  |  | 
|  | // If the storage_buffer_expr is resolves to a variable (typically | 
|  | // true) then key the array_length from the variable. If not, key off | 
|  | // the expression semantic node, which will be unique per call to | 
|  | // arrayLength(). | 
|  | const sem::Node* storage_buffer_usage = storage_buffer_sem; | 
|  | if (auto* user = storage_buffer_sem->As<sem::VariableUser>()) { | 
|  | storage_buffer_usage = user->Variable(); | 
|  | } | 
|  |  | 
|  | auto array_length = utils::GetOrCreate( | 
|  | array_length_by_usage, {block, storage_buffer_usage}, [&] { | 
|  | // First time this array length is used for this block. | 
|  | // Let's calculate it. | 
|  |  | 
|  | // Semantic info for the runtime array structure member | 
|  | auto* array_member_sem = storage_buffer_type->Members().back(); | 
|  |  | 
|  | // Construct the variable that'll hold the result of | 
|  | // RWByteAddressBuffer.GetDimensions() | 
|  | auto* buffer_size_result = ctx.dst->Decl( | 
|  | ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(), | 
|  | ast::StorageClass::kNone, ctx.dst->Expr(0u))); | 
|  |  | 
|  | // Call storage_buffer.GetDimensions(&buffer_size_result) | 
|  | auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call( | 
|  | // BufferSizeIntrinsic(X, ARGS...) is | 
|  | // translated to: | 
|  | //  X.GetDimensions(ARGS..) by the writer | 
|  | buffer_size, ctx.Clone(storage_buffer_expr), | 
|  | ctx.dst->AddressOf( | 
|  | ctx.dst->Expr(buffer_size_result->variable->symbol)))); | 
|  |  | 
|  | // Calculate actual array length | 
|  | //                total_storage_buffer_size - array_offset | 
|  | // array_length = ---------------------------------------- | 
|  | //                             array_stride | 
|  | auto name = ctx.dst->Sym(); | 
|  | uint32_t array_offset = array_member_sem->Offset(); | 
|  | uint32_t array_stride = array_member_sem->Size(); | 
|  | auto* array_length_var = ctx.dst->Decl(ctx.dst->Const( | 
|  | name, ctx.dst->ty.u32(), | 
|  | ctx.dst->Div( | 
|  | ctx.dst->Sub(buffer_size_result->variable->symbol, | 
|  | array_offset), | 
|  | array_stride))); | 
|  |  | 
|  | // Insert the array length calculations at the top of the block | 
|  | ctx.InsertBefore(block->statements, block->statements[0], | 
|  | buffer_size_result); | 
|  | ctx.InsertBefore(block->statements, block->statements[0], | 
|  | call_get_dims); | 
|  | ctx.InsertBefore(block->statements, block->statements[0], | 
|  | array_length_var); | 
|  | return name; | 
|  | }); | 
|  |  | 
|  | // Replace the call to arrayLength() with the array length variable | 
|  | ctx.Replace(call_expr, ctx.dst->Expr(array_length)); | 
|  | } | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | ctx.Clone(); | 
|  | } | 
|  |  | 
|  | }  // namespace transform | 
|  | }  // namespace tint |