tint/transform: Refactor transforms

Replace the ShouldRun() method with Apply(), which performs the
transformation if it is needed and otherwise returns
'SkipTransform'.

This reduces a bunch of duplicated scanning between the old ShouldRun()
and Transform().

This change also adjusts code style to make the transforms more
consistent.

Change-Id: I9a6b10cb8b4ed62676b12ef30fb7764d363386c6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107681
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/transform/calculate_array_length.cc b/src/tint/transform/calculate_array_length.cc
index 2ca5e54..9dcdd7b 100644
--- a/src/tint/transform/calculate_array_length.cc
+++ b/src/tint/transform/calculate_array_length.cc
@@ -40,6 +40,19 @@
 
 namespace {
 
+bool ShouldRun(const Program* program) {  // True if any function directly calls arrayLength()
+    for (auto* fn : program->AST().Functions()) {
+        if (auto* sem_fn = program->Sem().Get(fn)) {  // Skip functions with no semantic node
+            for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
+                if (builtin->Type() == sem::BuiltinType::kArrayLength) {
+                    return true;
+                }
+            }
+        }
+    }
+    return false;  // No function directly calls arrayLength()
+}
+
 /// ArrayUsage describes a runtime array usage.
 /// It is used as a key by the array_length_by_usage map.
 struct ArrayUsage {
@@ -73,21 +86,16 @@
 CalculateArrayLength::CalculateArrayLength() = default;
 CalculateArrayLength::~CalculateArrayLength() = default;
 
-bool CalculateArrayLength::ShouldRun(const Program* program, const DataMap&) const {
-    for (auto* fn : program->AST().Functions()) {
-        if (auto* sem_fn = program->Sem().Get(fn)) {
-            for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
-                if (builtin->Type() == sem::BuiltinType::kArrayLength) {
-                    return true;
-                }
-            }
-        }
+Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
+                                                   const DataMap&,
+                                                   DataMap&) const {
+    if (!ShouldRun(src)) {
+        return SkipTransform;
     }
-    return false;
-}
 
-void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
-    auto& sem = ctx.src->Sem();
+    ProgramBuilder b;
+    CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+    auto& sem = src->Sem();
 
     // get_buffer_size_intrinsic() emits the function decorated with
     // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
@@ -95,24 +103,20 @@
     std::unordered_map<const sem::Reference*, Symbol> buffer_size_intrinsics;
     auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) {
         return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
-            auto name = ctx.dst->Sym();
+            auto name = b.Sym();
             auto* type = CreateASTTypeFor(ctx, buffer_type);
-            auto* disable_validation =
-                ctx.dst->Disable(ast::DisabledValidation::kFunctionParameter);
-            ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
+            auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter);
+            b.AST().AddFunction(b.create<ast::Function>(
                 name,
                 utils::Vector{
-                    ctx.dst->Param("buffer",
-                                   ctx.dst->ty.pointer(type, buffer_type->AddressSpace(),
-                                                       buffer_type->Access()),
-                                   utils::Vector{disable_validation}),
-                    ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(),
-                                                                 ast::AddressSpace::kFunction)),
+                    b.Param("buffer",
+                            b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()),
+                            utils::Vector{disable_validation}),
+                    b.Param("result", b.ty.pointer(b.ty.u32(), ast::AddressSpace::kFunction)),
                 },
-                ctx.dst->ty.void_(), nullptr,
+                b.ty.void_(), nullptr,
                 utils::Vector{
-                    ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID(),
-                                                                    ctx.dst->AllocateNodeID()),
+                    b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()),
                 },
                 utils::Empty));
 
@@ -123,7 +127,7 @@
     std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage;
 
     // Find all the arrayLength() calls...
-    for (auto* node : ctx.src->ASTNodes().Objects()) {
+    for (auto* node : src->ASTNodes().Objects()) {
         if (auto* call_expr = node->As<ast::CallExpression>()) {
             auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
             if (auto* builtin = call->Target()->As<sem::Builtin>()) {
@@ -149,7 +153,7 @@
                     auto* arg = call_expr->args[0];
                     auto* address_of = arg->As<ast::UnaryOpExpression>();
                     if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
-                        TINT_ICE(Transform, ctx.dst->Diagnostics())
+                        TINT_ICE(Transform, b.Diagnostics())
                             << "arrayLength() expected address-of, got " << arg->TypeInfo().name;
                     }
                     auto* storage_buffer_expr = address_of->expr;
@@ -158,7 +162,7 @@
                     }
                     auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
                     if (!storage_buffer_sem) {
-                        TINT_ICE(Transform, ctx.dst->Diagnostics())
+                        TINT_ICE(Transform, b.Diagnostics())
                             << "expected form of arrayLength argument to be &array_var or "
                                "&struct_var.array_member";
                         break;
@@ -179,25 +183,24 @@
 
                             // Construct the variable that'll hold the result of
                             // RWByteAddressBuffer.GetDimensions()
-                            auto* buffer_size_result = ctx.dst->Decl(ctx.dst->Var(
-                                ctx.dst->Sym(), ctx.dst->ty.u32(), ctx.dst->Expr(0_u)));
+                            auto* buffer_size_result =
+                                b.Decl(b.Var(b.Sym(), b.ty.u32(), b.Expr(0_u)));
 
                             // Call storage_buffer.GetDimensions(&buffer_size_result)
-                            auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
+                            auto* call_get_dims = b.CallStmt(b.Call(
                                 // BufferSizeIntrinsic(X, ARGS...) is
                                 // translated to:
                                 //  X.GetDimensions(ARGS..) by the writer
-                                buffer_size, ctx.dst->AddressOf(ctx.Clone(storage_buffer_expr)),
-                                ctx.dst->AddressOf(
-                                    ctx.dst->Expr(buffer_size_result->variable->symbol))));
+                                buffer_size, b.AddressOf(ctx.Clone(storage_buffer_expr)),
+                                b.AddressOf(b.Expr(buffer_size_result->variable->symbol))));
 
                             // Calculate actual array length
                             //                total_storage_buffer_size - array_offset
                             // array_length = ----------------------------------------
                             //                             array_stride
-                            auto name = ctx.dst->Sym();
+                            auto name = b.Sym();
                             const ast::Expression* total_size =
-                                ctx.dst->Expr(buffer_size_result->variable);
+                                b.Expr(buffer_size_result->variable);
 
                             const sem::Array* array_type = Switch(
                                 storage_buffer_type->StoreType(),
@@ -205,23 +208,21 @@
                                     // The variable is a struct, so subtract the byte offset of
                                     // the array member.
                                     auto* array_member_sem = str->Members().back();
-                                    total_size =
-                                        ctx.dst->Sub(total_size, u32(array_member_sem->Offset()));
+                                    total_size = b.Sub(total_size, u32(array_member_sem->Offset()));
                                     return array_member_sem->Type()->As<sem::Array>();
                                 },
                                 [&](const sem::Array* arr) { return arr; });
 
                             if (!array_type) {
-                                TINT_ICE(Transform, ctx.dst->Diagnostics())
+                                TINT_ICE(Transform, b.Diagnostics())
                                     << "expected form of arrayLength argument to be "
                                        "&array_var or &struct_var.array_member";
                                 return name;
                             }
 
                             uint32_t array_stride = array_type->Size();
-                            auto* array_length_var = ctx.dst->Decl(
-                                ctx.dst->Let(name, ctx.dst->ty.u32(),
-                                             ctx.dst->Div(total_size, u32(array_stride))));
+                            auto* array_length_var = b.Decl(
+                                b.Let(name, b.ty.u32(), b.Div(total_size, u32(array_stride))));
 
                             // Insert the array length calculations at the top of the block
                             ctx.InsertBefore(block->statements, block->statements[0],
@@ -234,13 +235,14 @@
                         });
 
                     // Replace the call to arrayLength() with the array length variable
-                    ctx.Replace(call_expr, ctx.dst->Expr(array_length));
+                    ctx.Replace(call_expr, b.Expr(array_length));
                 }
             }
         }
     }
 
     ctx.Clone();
+    return Program(std::move(b));
 }
 
 }  // namespace tint::transform