blob: dd03dc232d6b518f82b62684dcb5f54762946029 [file] [log] [blame]
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "src/tint/transform/decompose_strided_matrix.h"
16
17#include <unordered_map>
18#include <utility>
19#include <vector>
20
21#include "src/tint/program_builder.h"
Ryan Harrisondbc13af2022-02-21 15:19:07 +000022#include "src/tint/sem/member_accessor_expression.h"
Ben Clayton3fb9a3f2023-02-04 21:20:26 +000023#include "src/tint/sem/value_expression.h"
Ryan Harrisondbc13af2022-02-21 15:19:07 +000024#include "src/tint/transform/simplify_pointers.h"
25#include "src/tint/utils/hash.h"
26#include "src/tint/utils/map.h"
27
28TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
29
dan sinclairb5599d32022-04-07 16:55:14 +000030namespace tint::transform {
Ryan Harrisondbc13af2022-02-21 15:19:07 +000031namespace {
32
33/// MatrixInfo describes a matrix member with a custom stride
34struct MatrixInfo {
dan sinclair41e4d9a2022-05-01 14:40:55 +000035 /// The stride in bytes between columns of the matrix
36 uint32_t stride = 0;
37 /// The type of the matrix
dan sinclair0e780da2022-12-08 22:21:24 +000038 const type::Matrix* matrix = nullptr;
Ryan Harrisondbc13af2022-02-21 15:19:07 +000039
Ben Clayton971318f2023-02-14 13:52:43 +000040 /// @returns the identifier of an array that holds an vector column for each row of the matrix.
41 ast::Type array(ProgramBuilder* b) const {
Ben Clayton9e367232023-02-08 14:17:37 +000042 return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()),
43 utils::Vector{
44 b->Stride(stride),
45 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +000046 }
dan sinclair41e4d9a2022-05-01 14:40:55 +000047
48 /// Equality operator
49 bool operator==(const MatrixInfo& info) const {
50 return stride == info.stride && matrix == info.matrix;
51 }
52 /// Hash function
53 struct Hasher {
54 size_t operator()(const MatrixInfo& t) const { return utils::Hash(t.stride, t.matrix); }
55 };
Ryan Harrisondbc13af2022-02-21 15:19:07 +000056};
57
Ben Claytonc6b38142022-11-03 08:41:19 +000058} // namespace
Ryan Harrisondbc13af2022-02-21 15:19:07 +000059
Ben Claytonc6b38142022-11-03 08:41:19 +000060DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
61
62DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
63
64Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
65 const DataMap&,
66 DataMap&) const {
67 ProgramBuilder b;
68 CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
69
70 // Scan the program for all storage and uniform structure matrix members with
71 // a custom stride attribute. Replace these matrices with an equivalent array,
72 // and populate the `decomposed` map with the members that have been replaced.
73 utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed;
74 for (auto* node : src->ASTNodes().Objects()) {
dan sinclair41e4d9a2022-05-01 14:40:55 +000075 if (auto* str = node->As<ast::Struct>()) {
Ben Claytonc6b38142022-11-03 08:41:19 +000076 auto* str_ty = src->Sem().Get(str);
dan sinclair2a651632023-02-19 04:03:55 +000077 if (!str_ty->UsedAs(builtin::AddressSpace::kUniform) &&
78 !str_ty->UsedAs(builtin::AddressSpace::kStorage)) {
dan sinclair41e4d9a2022-05-01 14:40:55 +000079 continue;
80 }
81 for (auto* member : str_ty->Members()) {
dan sinclair0e780da2022-12-08 22:21:24 +000082 auto* matrix = member->Type()->As<type::Matrix>();
dan sinclair41e4d9a2022-05-01 14:40:55 +000083 if (!matrix) {
84 continue;
85 }
86 auto* attr =
87 ast::GetAttribute<ast::StrideAttribute>(member->Declaration()->attributes);
88 if (!attr) {
89 continue;
90 }
91 uint32_t stride = attr->stride;
92 if (matrix->ColumnStride() == stride) {
93 continue;
94 }
Ben Claytonc6b38142022-11-03 08:41:19 +000095 // We've got ourselves a struct member of a matrix type with a custom
96 // stride. Replace this with an array of column vectors.
97 MatrixInfo info{stride, matrix};
98 auto* replacement =
99 b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
100 ctx.Replace(member->Declaration(), replacement);
101 decomposed.Add(member->Declaration(), info);
dan sinclair41e4d9a2022-05-01 14:40:55 +0000102 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000103 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000104 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000105
Ben Claytonc6b38142022-11-03 08:41:19 +0000106 if (decomposed.IsEmpty()) {
107 return SkipTransform;
108 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000109
dan sinclair41e4d9a2022-05-01 14:40:55 +0000110 // For all expressions where a single matrix column vector was indexed, we can
111 // preserve these without calling conversion functions.
112 // Example:
113 // ssbo.mat[2] -> ssbo.mat[2]
114 ctx.ReplaceAll(
115 [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
Ben Claytonc6b38142022-11-03 08:41:19 +0000116 if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
117 if (decomposed.Contains(access->Member()->Declaration())) {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000118 auto* obj = ctx.CloneWithoutTransform(expr->object);
119 auto* idx = ctx.Clone(expr->index);
Ben Claytonc6b38142022-11-03 08:41:19 +0000120 return b.IndexAccessor(obj, idx);
dan sinclair41e4d9a2022-05-01 14:40:55 +0000121 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000122 }
dan sinclair41e4d9a2022-05-01 14:40:55 +0000123 return nullptr;
124 });
125
126 // For all struct member accesses to the matrix on the LHS of an assignment,
127 // we need to convert the matrix to the array before assigning to the
128 // structure.
129 // Example:
130 // ssbo.mat = mat_to_arr(m)
131 std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
132 ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
Ben Claytonc6b38142022-11-03 08:41:19 +0000133 if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
Ben Clayton7c6e2292022-11-23 21:04:25 +0000134 if (auto info = decomposed.Find(access->Member()->Declaration())) {
Ben Claytonc6b38142022-11-03 08:41:19 +0000135 auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
136 auto name =
137 b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
138 std::to_string(info->matrix->rows()) + "_stride_" +
139 std::to_string(info->stride) + "_to_arr");
140
141 auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
142 auto array = [&] { return info->array(ctx.dst); };
143
144 auto mat = b.Sym("m");
145 utils::Vector<const ast::Expression*, 4> columns;
146 for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
147 columns.Push(b.IndexAccessor(mat, u32(i)));
148 }
149 b.Func(name,
150 utils::Vector{
151 b.Param(mat, matrix()),
152 },
153 array(),
154 utils::Vector{
Ben Clayton01ac21c2023-02-07 16:14:25 +0000155 b.Return(b.Call(array(), columns)),
Ben Claytonc6b38142022-11-03 08:41:19 +0000156 });
157 return name;
158 });
159 auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
160 auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs));
161 return b.Assign(lhs, rhs);
dan sinclair41e4d9a2022-05-01 14:40:55 +0000162 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000163 }
164 return nullptr;
dan sinclair41e4d9a2022-05-01 14:40:55 +0000165 });
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000166
dan sinclair41e4d9a2022-05-01 14:40:55 +0000167 // For all other struct member accesses, we need to convert the array to the
168 // matrix type. Example:
169 // m = arr_to_mat(ssbo.mat)
170 std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
171 ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
Ben Clayton2f9a9882022-12-17 02:20:04 +0000172 if (auto* access = src->Sem().Get(expr)->UnwrapLoad()->As<sem::StructMemberAccess>()) {
Ben Clayton7c6e2292022-11-23 21:04:25 +0000173 if (auto info = decomposed.Find(access->Member()->Declaration())) {
Ben Claytonc6b38142022-11-03 08:41:19 +0000174 auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
175 auto name =
176 b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
177 "x" + std::to_string(info->matrix->rows()) + "_stride_" +
178 std::to_string(info->stride));
179
180 auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
181 auto array = [&] { return info->array(ctx.dst); };
182
183 auto arr = b.Sym("arr");
184 utils::Vector<const ast::Expression*, 4> columns;
185 for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
186 columns.Push(b.IndexAccessor(arr, u32(i)));
187 }
188 b.Func(name,
189 utils::Vector{
190 b.Param(arr, array()),
191 },
192 matrix(),
193 utils::Vector{
Ben Clayton01ac21c2023-02-07 16:14:25 +0000194 b.Return(b.Call(matrix(), columns)),
Ben Claytonc6b38142022-11-03 08:41:19 +0000195 });
196 return name;
197 });
198 return b.Call(fn, ctx.CloneWithoutTransform(expr));
dan sinclair41e4d9a2022-05-01 14:40:55 +0000199 }
dan sinclair41e4d9a2022-05-01 14:40:55 +0000200 }
201 return nullptr;
202 });
203
204 ctx.Clone();
Ben Claytonc6b38142022-11-03 08:41:19 +0000205 return Program(std::move(b));
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000206}
207
dan sinclairb5599d32022-04-07 16:55:14 +0000208} // namespace tint::transform