Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 1 | // Copyright 2021 The Tint Authors. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
#include "src/tint/transform/decompose_strided_matrix.h"

#include <unordered_map>
#include <utility>
#include <vector>

#include "src/tint/program_builder.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/value_expression.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/type/f16.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/map.h"
| 27 | |
// Instantiates the Tint TypeInfo for DecomposeStridedMatrix (presumably what enables the
// Castable Is<>/As<> queries on this transform type — confirm against the macro definition).
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
| 29 | |
dan sinclair | b5599d3 | 2022-04-07 16:55:14 +0000 | [diff] [blame] | 30 | namespace tint::transform { |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 31 | namespace { |
| 32 | |
| 33 | /// MatrixInfo describes a matrix member with a custom stride |
| 34 | struct MatrixInfo { |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 35 | /// The stride in bytes between columns of the matrix |
| 36 | uint32_t stride = 0; |
| 37 | /// The type of the matrix |
dan sinclair | 0e780da | 2022-12-08 22:21:24 +0000 | [diff] [blame] | 38 | const type::Matrix* matrix = nullptr; |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 39 | |
Ben Clayton | 971318f | 2023-02-14 13:52:43 +0000 | [diff] [blame] | 40 | /// @returns the identifier of an array that holds an vector column for each row of the matrix. |
| 41 | ast::Type array(ProgramBuilder* b) const { |
Ben Clayton | 9e36723 | 2023-02-08 14:17:37 +0000 | [diff] [blame] | 42 | return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()), |
| 43 | utils::Vector{ |
| 44 | b->Stride(stride), |
| 45 | }); |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 46 | } |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 47 | |
| 48 | /// Equality operator |
| 49 | bool operator==(const MatrixInfo& info) const { |
| 50 | return stride == info.stride && matrix == info.matrix; |
| 51 | } |
| 52 | /// Hash function |
| 53 | struct Hasher { |
| 54 | size_t operator()(const MatrixInfo& t) const { return utils::Hash(t.stride, t.matrix); } |
| 55 | }; |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 56 | }; |
| 57 | |
Ben Clayton | c6b3814 | 2022-11-03 08:41:19 +0000 | [diff] [blame] | 58 | } // namespace |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 59 | |
// Constructor. Uses the compiler-generated default, defined out of line in this file.
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;

// Destructor. Uses the compiler-generated default, defined out of line in this file.
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
| 63 | |
/// Runs the transform on `src`.
///
/// The transform proceeds in three steps:
///  1. Find struct members with these properties:
///     - the struct is used in the uniform or storage address space;
///     - the member has matrix type;
///     - the member carries an `ast::StrideAttribute` whose value differs from
///       `type::Matrix::ColumnStride()`.
///     Each such member is replaced with an array of column vectors built by
///     `MatrixInfo::array()`. Replaced members are recorded in `decomposed`.
///  2. Register clone-time replacements so that uses of those members still see a matrix:
///     - indexing a single column is kept as-is;
///     - whole-matrix assignment goes through a generated matrix -> array function;
///     - every other whole-matrix member access goes through a generated array -> matrix
///       function.
///  3. Clone the program, applying the replacements registered above.
///
/// @param src the program to transform
/// @returns the transformed program, or `SkipTransform` if no struct member needed decomposing
Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
                                                     const DataMap&,
                                                     DataMap&) const {
    ProgramBuilder b;
    CloneContext ctx{&b, src, /* auto_clone_symbols */ true};

    // Scan the program for all storage and uniform structure matrix members with
    // a custom stride attribute. Replace these matrices with an equivalent array,
    // and populate the `decomposed` map with the members that have been replaced.
    utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed;
    for (auto* node : src->ASTNodes().Objects()) {
        if (auto* str = node->As<ast::Struct>()) {
            auto* str_ty = src->Sem().Get(str);
            if (!str_ty->UsedAs(builtin::AddressSpace::kUniform) &&
                !str_ty->UsedAs(builtin::AddressSpace::kStorage)) {
                continue;
            }
            for (auto* member : str_ty->Members()) {
                auto* matrix = member->Type()->As<type::Matrix>();
                if (!matrix) {
                    continue;
                }
                auto* attr =
                    ast::GetAttribute<ast::StrideAttribute>(member->Declaration()->attributes);
                if (!attr) {
                    continue;
                }
                uint32_t stride = attr->stride;
                // A stride equal to the matrix's own column stride needs no decomposition.
                if (matrix->ColumnStride() == stride) {
                    continue;
                }
                // We've got ourselves a struct member of a matrix type with a custom
                // stride. Replace this with an array of column vectors.
                // The replacement keeps the member's offset and name; its type becomes the
                // strided array.
                MatrixInfo info{stride, matrix};
                auto* replacement =
                    b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
                ctx.Replace(member->Declaration(), replacement);
                decomposed.Add(member->Declaration(), info);
            }
        }
    }

    // Nothing was decomposed, so report that this transform has no work to do.
    if (decomposed.IsEmpty()) {
        return SkipTransform;
    }

    // For all expressions where a single matrix column vector was indexed, we can
    // preserve these without calling conversion functions.
    // Example:
    //   ssbo.mat[2] -> ssbo.mat[2]
    ctx.ReplaceAll(
        [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
            if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
                if (decomposed.Contains(access->Member()->Declaration())) {
                    // The object (e.g. `ssbo.mat`) is cloned without transform. This stops
                    // the MemberAccessorExpression handler below from wrapping it in an
                    // array -> matrix conversion. Indexing the array already yields the
                    // requested column vector. The index itself is cloned normally.
                    auto* obj = ctx.CloneWithoutTransform(expr->object);
                    auto* idx = ctx.Clone(expr->index);
                    return b.IndexAccessor(obj, idx);
                }
            }
            return nullptr;
        });

    // For all struct member accesses to the matrix on the LHS of an assignment,
    // we need to convert the matrix to the array before assigning to the
    // structure.
    // Example:
    //   ssbo.mat = mat_to_arr(m)
    // One conversion function is generated per distinct (stride, matrix type) pair and
    // reused for later assignments with the same pair.
    std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
    ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
        if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
            if (auto info = decomposed.Find(access->Member()->Declaration())) {
                auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
                    // Generates a function of the form:
                    //   fn mat<C>x<R>_stride_<S>_to_arr(m : <matrix>) -> <array> {
                    //     return <array>(m[0u], m[1u], ..., m[C-1]);
                    //   }
                    // The name is allocated through b.Symbols().New(...).
                    auto name =
                        b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
                                        std::to_string(info->matrix->rows()) + "_stride_" +
                                        std::to_string(info->stride) + "_to_arr");

                    // Lambdas, because each use site needs its own freshly built AST type.
                    auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
                    auto array = [&] { return info->array(ctx.dst); };

                    auto mat = b.Sym("m");
                    utils::Vector<const ast::Expression*, 4> columns;
                    for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
                        columns.Push(b.IndexAccessor(mat, u32(i)));
                    }
                    b.Func(name,
                           utils::Vector{
                               b.Param(mat, matrix()),
                           },
                           array(),
                           utils::Vector{
                               b.Return(b.Call(array(), columns)),
                           });
                    return name;
                });
                // The LHS is cloned without transform, so it stays an array-typed member
                // access rather than being wrapped by the array -> matrix handler below.
                auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
                auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs));
                return b.Assign(lhs, rhs);
            }
        }
        return nullptr;
    });

    // For all other struct member accesses, we need to convert the array to the
    // matrix type. Example:
    //   m = arr_to_mat(ssbo.mat)
    // As above, one conversion function is generated per distinct (stride, matrix type) pair.
    std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
    ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
        // UnwrapLoad() looks through any load wrapping the member access. Without it, a
        // value read of `ssbo.mat` would not be recognised as a StructMemberAccess.
        if (auto* access = src->Sem().Get(expr)->UnwrapLoad()->As<sem::StructMemberAccess>()) {
            if (auto info = decomposed.Find(access->Member()->Declaration())) {
                auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
                    // Generates a function of the form:
                    //   fn arr_to_mat<C>x<R>_stride_<S>(arr : <array>) -> <matrix> {
                    //     return <matrix>(arr[0u], arr[1u], ..., arr[C-1]);
                    //   }
                    auto name =
                        b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
                                        "x" + std::to_string(info->matrix->rows()) + "_stride_" +
                                        std::to_string(info->stride));

                    auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
                    auto array = [&] { return info->array(ctx.dst); };

                    auto arr = b.Sym("arr");
                    utils::Vector<const ast::Expression*, 4> columns;
                    for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
                        columns.Push(b.IndexAccessor(arr, u32(i)));
                    }
                    b.Func(name,
                           utils::Vector{
                               b.Param(arr, array()),
                           },
                           matrix(),
                           utils::Vector{
                               b.Return(b.Call(matrix(), columns)),
                           });
                    return name;
                });
                // Clone the access itself without transform, to avoid re-entering this
                // handler, and wrap it in the conversion call.
                return b.Call(fn, ctx.CloneWithoutTransform(expr));
            }
        }
        return nullptr;
    });

    ctx.Clone();
    return Program(std::move(b));
}
| 207 | |
dan sinclair | b5599d3 | 2022-04-07 16:55:14 +0000 | [diff] [blame] | 208 | } // namespace tint::transform |