tint/transform: Refactor transforms
Replace the ShouldRun() method with Apply(), which performs the
transformation if it is needed and otherwise returns
'SkipTransform'.
This removes the duplicated program scanning that was previously done
by both the old ShouldRun() and Run() methods.
This change also adjusts code style to make the transforms more
consistent.
Change-Id: I9a6b10cb8b4ed62676b12ef30fb7764d363386c6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107681
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/transform/decompose_strided_matrix.cc b/src/tint/transform/decompose_strided_matrix.cc
index 91aed43..5494ca2 100644
--- a/src/tint/transform/decompose_strided_matrix.cc
+++ b/src/tint/transform/decompose_strided_matrix.cc
@@ -53,24 +53,25 @@
};
};
-/// Return type of the callback function of GatherCustomStrideMatrixMembers
-enum GatherResult { kContinue, kStop };
+} // namespace
-/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
-/// storage and uniform structs, which are of a matrix type, and have a custom
-/// matrix stride attribute. For each matrix member found, `callback` is called.
-/// `callback` is a function with the signature:
-/// GatherResult(const sem::StructMember* member,
-/// sem::Matrix* matrix,
-/// uint32_t stride)
-/// If `callback` return GatherResult::kStop, then the scanning will immediately
-/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
-/// scanning will continue.
-template <typename F>
-void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
- for (auto* node : program->ASTNodes().Objects()) {
+DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
+
+DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
+
+Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ // Scan the program for all storage and uniform structure matrix members with
+ // a custom stride attribute. Replace these matrices with an equivalent array,
+ // and populate the `decomposed` map with the members that have been replaced.
+ utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed;
+ for (auto* node : src->ASTNodes().Objects()) {
if (auto* str = node->As<ast::Struct>()) {
- auto* str_ty = program->Sem().Get(str);
+ auto* str_ty = src->Sem().Get(str);
if (!str_ty->UsedAs(ast::AddressSpace::kUniform) &&
!str_ty->UsedAs(ast::AddressSpace::kStorage)) {
continue;
@@ -89,46 +90,20 @@
if (matrix->ColumnStride() == stride) {
continue;
}
- if (callback(member, matrix, stride) == GatherResult::kStop) {
- return;
- }
+ // We've got ourselves a struct member of a matrix type with a custom
+ // stride. Replace this with an array of column vectors.
+ MatrixInfo info{stride, matrix};
+ auto* replacement =
+ b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
+ ctx.Replace(member->Declaration(), replacement);
+ decomposed.Add(member->Declaration(), info);
}
}
}
-}
-} // namespace
-
-DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
-
-DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
-
-bool DecomposeStridedMatrix::ShouldRun(const Program* program, const DataMap&) const {
- bool should_run = false;
- GatherCustomStrideMatrixMembers(program,
- [&](const sem::StructMember*, const sem::Matrix*, uint32_t) {
- should_run = true;
- return GatherResult::kStop;
- });
- return should_run;
-}
-
-void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- // Scan the program for all storage and uniform structure matrix members with
- // a custom stride attribute. Replace these matrices with an equivalent array,
- // and populate the `decomposed` map with the members that have been replaced.
- std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
- GatherCustomStrideMatrixMembers(
- ctx.src, [&](const sem::StructMember* member, const sem::Matrix* matrix, uint32_t stride) {
- // We've got ourselves a struct member of a matrix type with a custom
- // stride. Replace this with an array of column vectors.
- MatrixInfo info{stride, matrix};
- auto* replacement =
- ctx.dst->Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
- ctx.Replace(member->Declaration(), replacement);
- decomposed.emplace(member->Declaration(), info);
- return GatherResult::kContinue;
- });
+ if (decomposed.IsEmpty()) {
+ return SkipTransform;
+ }
// For all expressions where a single matrix column vector was indexed, we can
// preserve these without calling conversion functions.
@@ -136,12 +111,11 @@
// ssbo.mat[2] -> ssbo.mat[2]
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
- if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
- auto it = decomposed.find(access->Member()->Declaration());
- if (it != decomposed.end()) {
+ if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
+ if (decomposed.Contains(access->Member()->Declaration())) {
auto* obj = ctx.CloneWithoutTransform(expr->object);
auto* idx = ctx.Clone(expr->index);
- return ctx.dst->IndexAccessor(obj, idx);
+ return b.IndexAccessor(obj, idx);
}
}
return nullptr;
@@ -154,39 +128,36 @@
// ssbo.mat = mat_to_arr(m)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
- if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
- auto it = decomposed.find(access->Member()->Declaration());
- if (it == decomposed.end()) {
- return nullptr;
+ if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
+ if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+ auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
+ auto name =
+ b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
+ std::to_string(info->matrix->rows()) + "_stride_" +
+ std::to_string(info->stride) + "_to_arr");
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
+ auto array = [&] { return info->array(ctx.dst); };
+
+ auto mat = b.Sym("m");
+ utils::Vector<const ast::Expression*, 4> columns;
+ for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
+ columns.Push(b.IndexAccessor(mat, u32(i)));
+ }
+ b.Func(name,
+ utils::Vector{
+ b.Param(mat, matrix()),
+ },
+ array(),
+ utils::Vector{
+ b.Return(b.Construct(array(), columns)),
+ });
+ return name;
+ });
+ auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
+ auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs));
+ return b.Assign(lhs, rhs);
}
- MatrixInfo info = it->second;
- auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
- auto name =
- ctx.dst->Symbols().New("mat" + std::to_string(info.matrix->columns()) + "x" +
- std::to_string(info.matrix->rows()) + "_stride_" +
- std::to_string(info.stride) + "_to_arr");
-
- auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
- auto array = [&] { return info.array(ctx.dst); };
-
- auto mat = ctx.dst->Sym("m");
- utils::Vector<const ast::Expression*, 4> columns;
- for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
- columns.Push(ctx.dst->IndexAccessor(mat, u32(i)));
- }
- ctx.dst->Func(name,
- utils::Vector{
- ctx.dst->Param(mat, matrix()),
- },
- array(),
- utils::Vector{
- ctx.dst->Return(ctx.dst->Construct(array(), columns)),
- });
- return name;
- });
- auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
- auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
- return ctx.dst->Assign(lhs, rhs);
}
return nullptr;
});
@@ -196,41 +167,40 @@
// m = arr_to_mat(ssbo.mat)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
- if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
- auto it = decomposed.find(access->Member()->Declaration());
- if (it == decomposed.end()) {
- return nullptr;
+ if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) {
+ if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+ auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
+ auto name =
+ b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
+ "x" + std::to_string(info->matrix->rows()) + "_stride_" +
+ std::to_string(info->stride));
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
+ auto array = [&] { return info->array(ctx.dst); };
+
+ auto arr = b.Sym("arr");
+ utils::Vector<const ast::Expression*, 4> columns;
+ for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
+ columns.Push(b.IndexAccessor(arr, u32(i)));
+ }
+ b.Func(name,
+ utils::Vector{
+ b.Param(arr, array()),
+ },
+ matrix(),
+ utils::Vector{
+ b.Return(b.Construct(matrix(), columns)),
+ });
+ return name;
+ });
+ return b.Call(fn, ctx.CloneWithoutTransform(expr));
}
- MatrixInfo info = it->second;
- auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
- auto name = ctx.dst->Symbols().New(
- "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
- std::to_string(info.matrix->rows()) + "_stride_" + std::to_string(info.stride));
-
- auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
- auto array = [&] { return info.array(ctx.dst); };
-
- auto arr = ctx.dst->Sym("arr");
- utils::Vector<const ast::Expression*, 4> columns;
- for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
- columns.Push(ctx.dst->IndexAccessor(arr, u32(i)));
- }
- ctx.dst->Func(name,
- utils::Vector{
- ctx.dst->Param(arr, array()),
- },
- matrix(),
- utils::Vector{
- ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
- });
- return name;
- });
- return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
}
return nullptr;
});
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform