tint/transform: Refactor transforms

Replace the ShouldRun() method with Apply(), which performs the
transformation if it is needed, and otherwise returns
'SkipTransform'.

This removes much of the program scanning that was duplicated between
the old ShouldRun() and Run() methods.

This change also adjusts code style to make the transforms more
consistent.

Change-Id: I9a6b10cb8b4ed62676b12ef30fb7764d363386c6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107681
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/transform/decompose_strided_matrix.cc b/src/tint/transform/decompose_strided_matrix.cc
index 91aed43..5494ca2 100644
--- a/src/tint/transform/decompose_strided_matrix.cc
+++ b/src/tint/transform/decompose_strided_matrix.cc
@@ -53,24 +53,25 @@
     };
 };
 
-/// Return type of the callback function of GatherCustomStrideMatrixMembers
-enum GatherResult { kContinue, kStop };
+}  // namespace
 
-/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
-/// storage and uniform structs, which are of a matrix type, and have a custom
-/// matrix stride attribute. For each matrix member found, `callback` is called.
-/// `callback` is a function with the signature:
-///      GatherResult(const sem::StructMember* member,
-///                   sem::Matrix* matrix,
-///                   uint32_t stride)
-/// If `callback` return GatherResult::kStop, then the scanning will immediately
-/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
-/// scanning will continue.
-template <typename F>
-void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
-    for (auto* node : program->ASTNodes().Objects()) {
+DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
+
+DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
+
+Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
+                                                     const DataMap&,
+                                                     DataMap&) const {
+    ProgramBuilder b;
+    CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+    // Scan the program for all storage and uniform structure matrix members with
+    // a custom stride attribute. Replace these matrices with an equivalent array,
+    // and populate the `decomposed` map with the members that have been replaced.
+    utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed;
+    for (auto* node : src->ASTNodes().Objects()) {
         if (auto* str = node->As<ast::Struct>()) {
-            auto* str_ty = program->Sem().Get(str);
+            auto* str_ty = src->Sem().Get(str);
             if (!str_ty->UsedAs(ast::AddressSpace::kUniform) &&
                 !str_ty->UsedAs(ast::AddressSpace::kStorage)) {
                 continue;
@@ -89,46 +90,20 @@
                 if (matrix->ColumnStride() == stride) {
                     continue;
                 }
-                if (callback(member, matrix, stride) == GatherResult::kStop) {
-                    return;
-                }
+                // We've got ourselves a struct member of a matrix type with a custom
+                // stride. Replace this with an array of column vectors.
+                MatrixInfo info{stride, matrix};
+                auto* replacement =
+                    b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
+                ctx.Replace(member->Declaration(), replacement);
+                decomposed.Add(member->Declaration(), info);
             }
         }
     }
-}
 
-}  // namespace
-
-DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
-
-DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
-
-bool DecomposeStridedMatrix::ShouldRun(const Program* program, const DataMap&) const {
-    bool should_run = false;
-    GatherCustomStrideMatrixMembers(program,
-                                    [&](const sem::StructMember*, const sem::Matrix*, uint32_t) {
-                                        should_run = true;
-                                        return GatherResult::kStop;
-                                    });
-    return should_run;
-}
-
-void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
-    // Scan the program for all storage and uniform structure matrix members with
-    // a custom stride attribute. Replace these matrices with an equivalent array,
-    // and populate the `decomposed` map with the members that have been replaced.
-    std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
-    GatherCustomStrideMatrixMembers(
-        ctx.src, [&](const sem::StructMember* member, const sem::Matrix* matrix, uint32_t stride) {
-            // We've got ourselves a struct member of a matrix type with a custom
-            // stride. Replace this with an array of column vectors.
-            MatrixInfo info{stride, matrix};
-            auto* replacement =
-                ctx.dst->Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
-            ctx.Replace(member->Declaration(), replacement);
-            decomposed.emplace(member->Declaration(), info);
-            return GatherResult::kContinue;
-        });
+    if (decomposed.IsEmpty()) {
+        return SkipTransform;
+    }
 
     // For all expressions where a single matrix column vector was indexed, we can
     // preserve these without calling conversion functions.
@@ -136,12 +111,11 @@
     //   ssbo.mat[2] -> ssbo.mat[2]
     ctx.ReplaceAll(
         [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
-            if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
-                auto it = decomposed.find(access->Member()->Declaration());
-                if (it != decomposed.end()) {
+            if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
+                if (decomposed.Contains(access->Member()->Declaration())) {
                     auto* obj = ctx.CloneWithoutTransform(expr->object);
                     auto* idx = ctx.Clone(expr->index);
-                    return ctx.dst->IndexAccessor(obj, idx);
+                    return b.IndexAccessor(obj, idx);
                 }
             }
             return nullptr;
@@ -154,39 +128,36 @@
     //   ssbo.mat = mat_to_arr(m)
     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
     ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
-        if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
-            auto it = decomposed.find(access->Member()->Declaration());
-            if (it == decomposed.end()) {
-                return nullptr;
+        if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
+            if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+                auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
+                    auto name =
+                        b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
+                                        std::to_string(info->matrix->rows()) + "_stride_" +
+                                        std::to_string(info->stride) + "_to_arr");
+
+                    auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
+                    auto array = [&] { return info->array(ctx.dst); };
+
+                    auto mat = b.Sym("m");
+                    utils::Vector<const ast::Expression*, 4> columns;
+                    for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
+                        columns.Push(b.IndexAccessor(mat, u32(i)));
+                    }
+                    b.Func(name,
+                           utils::Vector{
+                               b.Param(mat, matrix()),
+                           },
+                           array(),
+                           utils::Vector{
+                               b.Return(b.Construct(array(), columns)),
+                           });
+                    return name;
+                });
+                auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
+                auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs));
+                return b.Assign(lhs, rhs);
             }
-            MatrixInfo info = it->second;
-            auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
-                auto name =
-                    ctx.dst->Symbols().New("mat" + std::to_string(info.matrix->columns()) + "x" +
-                                           std::to_string(info.matrix->rows()) + "_stride_" +
-                                           std::to_string(info.stride) + "_to_arr");
-
-                auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
-                auto array = [&] { return info.array(ctx.dst); };
-
-                auto mat = ctx.dst->Sym("m");
-                utils::Vector<const ast::Expression*, 4> columns;
-                for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
-                    columns.Push(ctx.dst->IndexAccessor(mat, u32(i)));
-                }
-                ctx.dst->Func(name,
-                              utils::Vector{
-                                  ctx.dst->Param(mat, matrix()),
-                              },
-                              array(),
-                              utils::Vector{
-                                  ctx.dst->Return(ctx.dst->Construct(array(), columns)),
-                              });
-                return name;
-            });
-            auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
-            auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
-            return ctx.dst->Assign(lhs, rhs);
         }
         return nullptr;
     });
@@ -196,41 +167,40 @@
     //   m = arr_to_mat(ssbo.mat)
     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
     ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
-        if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
-            auto it = decomposed.find(access->Member()->Declaration());
-            if (it == decomposed.end()) {
-                return nullptr;
+        if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) {
+            if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+                auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
+                    auto name =
+                        b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
+                                        "x" + std::to_string(info->matrix->rows()) + "_stride_" +
+                                        std::to_string(info->stride));
+
+                    auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
+                    auto array = [&] { return info->array(ctx.dst); };
+
+                    auto arr = b.Sym("arr");
+                    utils::Vector<const ast::Expression*, 4> columns;
+                    for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
+                        columns.Push(b.IndexAccessor(arr, u32(i)));
+                    }
+                    b.Func(name,
+                           utils::Vector{
+                               b.Param(arr, array()),
+                           },
+                           matrix(),
+                           utils::Vector{
+                               b.Return(b.Construct(matrix(), columns)),
+                           });
+                    return name;
+                });
+                return b.Call(fn, ctx.CloneWithoutTransform(expr));
             }
-            MatrixInfo info = it->second;
-            auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
-                auto name = ctx.dst->Symbols().New(
-                    "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
-                    std::to_string(info.matrix->rows()) + "_stride_" + std::to_string(info.stride));
-
-                auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
-                auto array = [&] { return info.array(ctx.dst); };
-
-                auto arr = ctx.dst->Sym("arr");
-                utils::Vector<const ast::Expression*, 4> columns;
-                for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
-                    columns.Push(ctx.dst->IndexAccessor(arr, u32(i)));
-                }
-                ctx.dst->Func(name,
-                              utils::Vector{
-                                  ctx.dst->Param(arr, array()),
-                              },
-                              matrix(),
-                              utils::Vector{
-                                  ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
-                              });
-                return name;
-            });
-            return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
         }
         return nullptr;
     });
 
     ctx.Clone();
+    return Program(std::move(b));
 }
 
 }  // namespace tint::transform