[ir] Update style of ExpandImplicitSplats

Make it match other transforms. Also remove the worklist, since the
instructions iterator is not affected by additions to the instruction
list.

Change-Id: I5de3f54e07c982f51ab3bf29fac814001fb07da6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/186281
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
index 0a6ba91..3b650a7 100644
--- a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
+++ b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
@@ -41,50 +41,56 @@
 
 namespace {
 
-void Run(core::ir::Module& ir) {
+/// PIMPL state for the transform.
+struct State {
+    /// The IR module.
+    core::ir::Module& ir;
+
+    /// The IR builder.
     core::ir::Builder b{ir};
 
-    // Find the instructions that use implicit splats and either modify them in place or record them
-    // to be replaced in a second pass.
-    Vector<core::ir::CoreBinary*, 4> binary_worklist;
-    Vector<core::ir::CoreBuiltinCall*, 4> builtin_worklist;
-    for (auto* inst : ir.Instructions()) {
-        if (auto* construct = inst->As<core::ir::Construct>()) {
-            // A vector constructor with a single scalar argument needs to be modified to replicate
-            // the argument N times.
-            auto* vec = construct->Result(0)->Type()->As<core::type::Vector>();
-            if (vec &&  //
-                construct->Args().Length() == 1 &&
-                construct->Args()[0]->Type()->Is<core::type::Scalar>()) {
-                for (uint32_t i = 1; i < vec->Width(); i++) {
-                    construct->AppendArg(construct->Args()[0]);
+    /// Process the module.
+    void Process() {
+        // Find the instructions that use implicit splats and modify or replace them.
+        for (auto* inst : ir.Instructions()) {
+            if (auto* construct = inst->As<core::ir::Construct>()) {
+                // A vector constructor with a single scalar argument needs to be modified to
+                // replicate the argument N times.
+                auto* vec = construct->Result(0)->Type()->As<core::type::Vector>();
+                if (vec &&  //
+                    construct->Args().Length() == 1 &&
+                    construct->Args()[0]->Type()->Is<core::type::Scalar>()) {
+                    for (uint32_t i = 1; i < vec->Width(); i++) {
+                        construct->AppendArg(construct->Args()[0]);
+                    }
                 }
-            }
-        } else if (auto* binary = inst->As<core::ir::CoreBinary>()) {
-            // A binary instruction that mixes vector and scalar operands needs to have the scalar
-            // operand replaced with an explicit vector constructor.
-            if (binary->Result(0)->Type()->Is<core::type::Vector>()) {
-                if (binary->LHS()->Type()->Is<core::type::Scalar>() ||
-                    binary->RHS()->Type()->Is<core::type::Scalar>()) {
-                    binary_worklist.Push(binary);
+            } else if (auto* binary = inst->As<core::ir::CoreBinary>()) {
+                // A binary instruction that mixes vector and scalar operands needs to have the
+                // scalar operand replaced with an explicit vector constructor.
+                if (binary->Result(0)->Type()->Is<core::type::Vector>()) {
+                    if (binary->LHS()->Type()->Is<core::type::Scalar>() ||
+                        binary->RHS()->Type()->Is<core::type::Scalar>()) {
+                        ExpandBinary(binary);
+                    }
                 }
-            }
-        } else if (auto* builtin = inst->As<core::ir::CoreBuiltinCall>()) {
-            // A mix builtin call that mixes vector and scalar operands needs to have the scalar
-            // operand replaced with an explicit vector constructor.
-            if (builtin->Func() == core::BuiltinFn::kMix) {
-                if (builtin->Result(0)->Type()->Is<core::type::Vector>()) {
-                    if (builtin->Args()[2]->Type()->Is<core::type::Scalar>()) {
-                        builtin_worklist.Push(builtin);
+            } else if (auto* builtin = inst->As<core::ir::CoreBuiltinCall>()) {
+                // A mix builtin call that mixes vector and scalar operands needs to have the scalar
+                // operand replaced with an explicit vector constructor.
+                if (builtin->Func() == core::BuiltinFn::kMix) {
+                    if (builtin->Result(0)->Type()->Is<core::type::Vector>()) {
+                        if (builtin->Args()[2]->Type()->Is<core::type::Scalar>()) {
+                            ExpandOperand(builtin,
+                                          core::ir::CoreBuiltinCall::kArgsOperandOffset + 2);
+                        }
                     }
                 }
             }
         }
     }
 
-    // Helper to expand a scalar operand of an instruction by replacing it with an explicitly
-    // constructed vector that matches the result type.
-    auto expand_operand = [&](core::ir::Instruction* inst, size_t operand_idx) {
+    /// Helper to expand a scalar operand of an instruction by replacing it with an explicitly
+    /// constructed vector that matches the result type.
+    void ExpandOperand(core::ir::Instruction* inst, size_t operand_idx) {
         auto* vec = inst->Result(0)->Type()->As<core::type::Vector>();
 
         Vector<core::ir::Value*, 4> args;
@@ -93,10 +99,11 @@
         auto* construct = b.Construct(vec, std::move(args));
         construct->InsertBefore(inst);
         inst->SetOperand(operand_idx, construct->Result(0));
-    };
+    }
 
-    // Replace scalar operands to binary instructions that produce vectors.
-    for (auto* binary : binary_worklist) {
+    /// Replace scalar operands to binary instructions that produce vectors.
+    /// @param binary the binary instruction to modify
+    void ExpandBinary(core::ir::Binary* binary) {
         auto* result_ty = binary->Result(0)->Type();
         if (result_ty->is_float_vector() && binary->Op() == core::BinaryOp::kMultiply) {
             // Use OpVectorTimesScalar for floating point multiply.
@@ -118,26 +125,13 @@
         } else {
             // Expand the scalar argument into an explicitly constructed vector.
             if (binary->LHS()->Type()->Is<core::type::Scalar>()) {
-                expand_operand(binary, core::ir::CoreBinary::kLhsOperandOffset);
+                ExpandOperand(binary, core::ir::CoreBinary::kLhsOperandOffset);
             } else if (binary->RHS()->Type()->Is<core::type::Scalar>()) {
-                expand_operand(binary, core::ir::CoreBinary::kRhsOperandOffset);
+                ExpandOperand(binary, core::ir::CoreBinary::kRhsOperandOffset);
             }
         }
     }
-
-    // Replace scalar arguments to builtin calls that produce vectors.
-    for (auto* builtin : builtin_worklist) {
-        switch (builtin->Func()) {
-            case core::BuiltinFn::kMix:
-                // Expand the scalar argument into an explicitly constructed vector.
-                expand_operand(builtin, core::ir::CoreBuiltinCall::kArgsOperandOffset + 2);
-                break;
-            default:
-                TINT_UNREACHABLE() << "unhandled builtin call";
-                break;
-        }
-    }
-}
+};
 
 }  // namespace
 
@@ -147,7 +141,7 @@
         return result.Failure();
     }
 
-    Run(ir);
+    State{ir}.Process();
 
     return Success;
 }