[ir] Update style of ExpandImplicitSplats
Make it match other transforms. Also remove the worklist, since the
instructions iterator is not affected by additions to the instruction
list.
Change-Id: I5de3f54e07c982f51ab3bf29fac814001fb07da6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/186281
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
index 0a6ba91..3b650a7 100644
--- a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
+++ b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
@@ -41,50 +41,56 @@
namespace {
-void Run(core::ir::Module& ir) {
+/// PIMPL state for the transform.
+struct State {
+ /// The IR module.
+ core::ir::Module& ir;
+
+ /// The IR builder.
core::ir::Builder b{ir};
- // Find the instructions that use implicit splats and either modify them in place or record them
- // to be replaced in a second pass.
- Vector<core::ir::CoreBinary*, 4> binary_worklist;
- Vector<core::ir::CoreBuiltinCall*, 4> builtin_worklist;
- for (auto* inst : ir.Instructions()) {
- if (auto* construct = inst->As<core::ir::Construct>()) {
- // A vector constructor with a single scalar argument needs to be modified to replicate
- // the argument N times.
- auto* vec = construct->Result(0)->Type()->As<core::type::Vector>();
- if (vec && //
- construct->Args().Length() == 1 &&
- construct->Args()[0]->Type()->Is<core::type::Scalar>()) {
- for (uint32_t i = 1; i < vec->Width(); i++) {
- construct->AppendArg(construct->Args()[0]);
+ /// Process the module.
+ void Process() {
+ // Find the instructions that use implicit splats and modify or replace them.
+ for (auto* inst : ir.Instructions()) {
+ if (auto* construct = inst->As<core::ir::Construct>()) {
+ // A vector constructor with a single scalar argument needs to be modified to
+ // replicate the argument N times.
+ auto* vec = construct->Result(0)->Type()->As<core::type::Vector>();
+ if (vec && //
+ construct->Args().Length() == 1 &&
+ construct->Args()[0]->Type()->Is<core::type::Scalar>()) {
+ for (uint32_t i = 1; i < vec->Width(); i++) {
+ construct->AppendArg(construct->Args()[0]);
+ }
}
- }
- } else if (auto* binary = inst->As<core::ir::CoreBinary>()) {
- // A binary instruction that mixes vector and scalar operands needs to have the scalar
- // operand replaced with an explicit vector constructor.
- if (binary->Result(0)->Type()->Is<core::type::Vector>()) {
- if (binary->LHS()->Type()->Is<core::type::Scalar>() ||
- binary->RHS()->Type()->Is<core::type::Scalar>()) {
- binary_worklist.Push(binary);
+ } else if (auto* binary = inst->As<core::ir::CoreBinary>()) {
+ // A binary instruction that mixes vector and scalar operands needs to have the
+ // scalar operand replaced with an explicit vector constructor.
+ if (binary->Result(0)->Type()->Is<core::type::Vector>()) {
+ if (binary->LHS()->Type()->Is<core::type::Scalar>() ||
+ binary->RHS()->Type()->Is<core::type::Scalar>()) {
+ ExpandBinary(binary);
+ }
}
- }
- } else if (auto* builtin = inst->As<core::ir::CoreBuiltinCall>()) {
- // A mix builtin call that mixes vector and scalar operands needs to have the scalar
- // operand replaced with an explicit vector constructor.
- if (builtin->Func() == core::BuiltinFn::kMix) {
- if (builtin->Result(0)->Type()->Is<core::type::Vector>()) {
- if (builtin->Args()[2]->Type()->Is<core::type::Scalar>()) {
- builtin_worklist.Push(builtin);
+ } else if (auto* builtin = inst->As<core::ir::CoreBuiltinCall>()) {
+ // A mix builtin call that mixes vector and scalar operands needs to have the scalar
+ // operand replaced with an explicit vector constructor.
+ if (builtin->Func() == core::BuiltinFn::kMix) {
+ if (builtin->Result(0)->Type()->Is<core::type::Vector>()) {
+ if (builtin->Args()[2]->Type()->Is<core::type::Scalar>()) {
+ ExpandOperand(builtin,
+ core::ir::CoreBuiltinCall::kArgsOperandOffset + 2);
+ }
}
}
}
}
}
- // Helper to expand a scalar operand of an instruction by replacing it with an explicitly
- // constructed vector that matches the result type.
- auto expand_operand = [&](core::ir::Instruction* inst, size_t operand_idx) {
+ /// Helper to expand a scalar operand of an instruction by replacing it with an explicitly
+ /// constructed vector that matches the result type.
+ void ExpandOperand(core::ir::Instruction* inst, size_t operand_idx) {
auto* vec = inst->Result(0)->Type()->As<core::type::Vector>();
Vector<core::ir::Value*, 4> args;
@@ -93,10 +99,11 @@
auto* construct = b.Construct(vec, std::move(args));
construct->InsertBefore(inst);
inst->SetOperand(operand_idx, construct->Result(0));
- };
+ }
- // Replace scalar operands to binary instructions that produce vectors.
- for (auto* binary : binary_worklist) {
+ /// Replace scalar operands to binary instructions that produce vectors.
+ /// @param binary the binary instruction to modify
+ void ExpandBinary(core::ir::Binary* binary) {
auto* result_ty = binary->Result(0)->Type();
if (result_ty->is_float_vector() && binary->Op() == core::BinaryOp::kMultiply) {
// Use OpVectorTimesScalar for floating point multiply.
@@ -118,26 +125,13 @@
} else {
// Expand the scalar argument into an explicitly constructed vector.
if (binary->LHS()->Type()->Is<core::type::Scalar>()) {
- expand_operand(binary, core::ir::CoreBinary::kLhsOperandOffset);
+ ExpandOperand(binary, core::ir::CoreBinary::kLhsOperandOffset);
} else if (binary->RHS()->Type()->Is<core::type::Scalar>()) {
- expand_operand(binary, core::ir::CoreBinary::kRhsOperandOffset);
+ ExpandOperand(binary, core::ir::CoreBinary::kRhsOperandOffset);
}
}
}
-
- // Replace scalar arguments to builtin calls that produce vectors.
- for (auto* builtin : builtin_worklist) {
- switch (builtin->Func()) {
- case core::BuiltinFn::kMix:
- // Expand the scalar argument into an explicitly constructed vector.
- expand_operand(builtin, core::ir::CoreBuiltinCall::kArgsOperandOffset + 2);
- break;
- default:
- TINT_UNREACHABLE() << "unhandled builtin call";
- break;
- }
- }
-}
+};
} // namespace
@@ -147,7 +141,7 @@
return result.Failure();
}
- Run(ir);
+ State{ir}.Process();
return Success;
}