| // Copyright 2021 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "src/tint/transform/decompose_strided_matrix.h" |
| |
| #include <unordered_map> |
| #include <utility> |
| #include <vector> |
| |
| #include "src/tint/program_builder.h" |
| #include "src/tint/sem/expression.h" |
| #include "src/tint/sem/member_accessor_expression.h" |
| #include "src/tint/transform/simplify_pointers.h" |
| #include "src/tint/utils/hash.h" |
| #include "src/tint/utils/map.h" |
| |
| TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix); |
| |
| namespace tint::transform { |
| namespace { |
| |
| /// MatrixInfo describes a matrix member with a custom stride |
| struct MatrixInfo { |
| /// The stride in bytes between columns of the matrix |
| uint32_t stride = 0; |
| /// The type of the matrix |
| const sem::Matrix* matrix = nullptr; |
| |
| /// @returns a new ast::Array that holds an vector column for each row of the |
| /// matrix. |
| const ast::Array* array(ProgramBuilder* b) const { |
| return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()), stride); |
| } |
| |
| /// Equality operator |
| bool operator==(const MatrixInfo& info) const { |
| return stride == info.stride && matrix == info.matrix; |
| } |
| /// Hash function |
| struct Hasher { |
| size_t operator()(const MatrixInfo& t) const { return utils::Hash(t.stride, t.matrix); } |
| }; |
| }; |
| |
| } // namespace |
| |
| DecomposeStridedMatrix::DecomposeStridedMatrix() = default; |
| |
| DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; |
| |
| Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src, |
| const DataMap&, |
| DataMap&) const { |
| ProgramBuilder b; |
| CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; |
| |
| // Scan the program for all storage and uniform structure matrix members with |
| // a custom stride attribute. Replace these matrices with an equivalent array, |
| // and populate the `decomposed` map with the members that have been replaced. |
| utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed; |
| for (auto* node : src->ASTNodes().Objects()) { |
| if (auto* str = node->As<ast::Struct>()) { |
| auto* str_ty = src->Sem().Get(str); |
| if (!str_ty->UsedAs(ast::AddressSpace::kUniform) && |
| !str_ty->UsedAs(ast::AddressSpace::kStorage)) { |
| continue; |
| } |
| for (auto* member : str_ty->Members()) { |
| auto* matrix = member->Type()->As<sem::Matrix>(); |
| if (!matrix) { |
| continue; |
| } |
| auto* attr = |
| ast::GetAttribute<ast::StrideAttribute>(member->Declaration()->attributes); |
| if (!attr) { |
| continue; |
| } |
| uint32_t stride = attr->stride; |
| if (matrix->ColumnStride() == stride) { |
| continue; |
| } |
| // We've got ourselves a struct member of a matrix type with a custom |
| // stride. Replace this with an array of column vectors. |
| MatrixInfo info{stride, matrix}; |
| auto* replacement = |
| b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst)); |
| ctx.Replace(member->Declaration(), replacement); |
| decomposed.Add(member->Declaration(), info); |
| } |
| } |
| } |
| |
| if (decomposed.IsEmpty()) { |
| return SkipTransform; |
| } |
| |
| // For all expressions where a single matrix column vector was indexed, we can |
| // preserve these without calling conversion functions. |
| // Example: |
| // ssbo.mat[2] -> ssbo.mat[2] |
| ctx.ReplaceAll( |
| [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { |
| if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) { |
| if (decomposed.Contains(access->Member()->Declaration())) { |
| auto* obj = ctx.CloneWithoutTransform(expr->object); |
| auto* idx = ctx.Clone(expr->index); |
| return b.IndexAccessor(obj, idx); |
| } |
| } |
| return nullptr; |
| }); |
| |
| // For all struct member accesses to the matrix on the LHS of an assignment, |
| // we need to convert the matrix to the array before assigning to the |
| // structure. |
| // Example: |
| // ssbo.mat = mat_to_arr(m) |
| std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr; |
| ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* { |
| if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) { |
| if (auto* info = decomposed.Find(access->Member()->Declaration())) { |
| auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] { |
| auto name = |
| b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" + |
| std::to_string(info->matrix->rows()) + "_stride_" + |
| std::to_string(info->stride) + "_to_arr"); |
| |
| auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); }; |
| auto array = [&] { return info->array(ctx.dst); }; |
| |
| auto mat = b.Sym("m"); |
| utils::Vector<const ast::Expression*, 4> columns; |
| for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) { |
| columns.Push(b.IndexAccessor(mat, u32(i))); |
| } |
| b.Func(name, |
| utils::Vector{ |
| b.Param(mat, matrix()), |
| }, |
| array(), |
| utils::Vector{ |
| b.Return(b.Construct(array(), columns)), |
| }); |
| return name; |
| }); |
| auto* lhs = ctx.CloneWithoutTransform(stmt->lhs); |
| auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs)); |
| return b.Assign(lhs, rhs); |
| } |
| } |
| return nullptr; |
| }); |
| |
| // For all other struct member accesses, we need to convert the array to the |
| // matrix type. Example: |
| // m = arr_to_mat(ssbo.mat) |
| std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat; |
| ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* { |
| if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) { |
| if (auto* info = decomposed.Find(access->Member()->Declaration())) { |
| auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] { |
| auto name = |
| b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) + |
| "x" + std::to_string(info->matrix->rows()) + "_stride_" + |
| std::to_string(info->stride)); |
| |
| auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); }; |
| auto array = [&] { return info->array(ctx.dst); }; |
| |
| auto arr = b.Sym("arr"); |
| utils::Vector<const ast::Expression*, 4> columns; |
| for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) { |
| columns.Push(b.IndexAccessor(arr, u32(i))); |
| } |
| b.Func(name, |
| utils::Vector{ |
| b.Param(arr, array()), |
| }, |
| matrix(), |
| utils::Vector{ |
| b.Return(b.Construct(matrix(), columns)), |
| }); |
| return name; |
| }); |
| return b.Call(fn, ctx.CloneWithoutTransform(expr)); |
| } |
| } |
| return nullptr; |
| }); |
| |
| ctx.Clone(); |
| return Program(std::move(b)); |
| } |
| |
| } // namespace tint::transform |