[tint][ir] Split BreakIf arguments into two lists

One for the next-iteration.
One for the loop exit.

Validation will be done as a followup CL.

Change-Id: Iee36c9043b70e9b7ffbc92ef1f28e77d9de996f9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/187686
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index eb320b2..c422433 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -341,6 +341,11 @@
         }
         inst_out->SetResults(std::move(results));
 
+        if (inst_in.has_break_if()) {
+            static_cast<BreakIf*>(inst_out)->SetNumNextIterValues(
+                inst_in.break_if().num_next_iter_values());
+        }
+
         return inst_out;
     }
 
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index 70305bd..74d00ac 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -253,7 +253,10 @@
 
     void InstructionBitcast(pb::InstructionBitcast&, const ir::Bitcast*) {}
 
-    void InstructionBreakIf(pb::InstructionBreakIf&, const ir::BreakIf*) {}
+    void InstructionBreakIf(pb::InstructionBreakIf& breakif_out, const ir::BreakIf* breakif_in) {
+        auto num_next_iter_values = static_cast<uint32_t>(breakif_in->NextIterValues().Length());
+        breakif_out.set_num_next_iter_values(num_next_iter_values);
+    }
 
     void InstructionBuiltinCall(pb::InstructionBuiltinCall& call_out,
                                 const ir::CoreBuiltinCall* call_in) {
diff --git a/src/tint/lang/core/ir/binary/ir.proto b/src/tint/lang/core/ir/binary/ir.proto
index 19f3b0a..31c2821 100644
--- a/src/tint/lang/core/ir/binary/ir.proto
+++ b/src/tint/lang/core/ir/binary/ir.proto
@@ -354,7 +354,9 @@
 
 message InstructionContinue {}
 
-message InstructionBreakIf {}
+message InstructionBreakIf {
+    uint32 num_next_iter_values = 1;
+}
 
 message InstructionUnreachable {}
 
diff --git a/src/tint/lang/core/ir/binary/roundtrip_test.cc b/src/tint/lang/core/ir/binary/roundtrip_test.cc
index 52a516b..76f92b8 100644
--- a/src/tint/lang/core/ir/binary/roundtrip_test.cc
+++ b/src/tint/lang/core/ir/binary/roundtrip_test.cc
@@ -698,7 +698,10 @@
 TEST_F(IRBinaryRoundtripTest, LoopBlockParams) {
     auto* fn = b.Function("Function", ty.void_());
     b.Append(fn->Block(), [&] {
+        auto* loop_res_a = b.InstructionResult(ty.i32());
+        auto* loop_res_b = b.InstructionResult(ty.f32());
         auto* loop = b.Loop();
+        loop->SetResults(Vector{loop_res_a, loop_res_b});
         b.Append(loop->Initializer(), [&] {
             b.Let("L", 1_i);
             b.NextIteration(loop);
@@ -710,7 +713,12 @@
         auto* z = b.BlockParam<u32>("z");
         auto* w = b.BlockParam<bool>("w");
         loop->Continuing()->SetParams({z, w});
-        b.Append(loop->Continuing(), [&] { b.BreakIf(loop, false, 3_i, 4_f); });
+        b.Append(loop->Continuing(), [&] {
+            b.BreakIf(loop,
+                      /* condition */ false,
+                      /* next iter */ b.Values(3_i, 4_f),
+                      /* exit */ b.Values(5_u, 6_i));
+        });
         b.Return(fn);
     });
     RUN_TEST();
diff --git a/src/tint/lang/core/ir/break_if.cc b/src/tint/lang/core/ir/break_if.cc
index 53a7145..efb2990 100644
--- a/src/tint/lang/core/ir/break_if.cc
+++ b/src/tint/lang/core/ir/break_if.cc
@@ -42,11 +42,16 @@
 
 BreakIf::BreakIf() = default;
 
-BreakIf::BreakIf(Value* condition, ir::Loop* loop, VectorRef<Value*> args) : loop_(loop) {
+BreakIf::BreakIf(Value* condition,
+                 ir::Loop* loop,
+                 VectorRef<Value*> next_iter_values /* = tint::Empty */,
+                 VectorRef<Value*> exit_values /* = tint::Empty */)
+    : loop_(loop), num_next_iter_values_(next_iter_values.Length()) {
     TINT_ASSERT(loop_);
 
     AddOperand(BreakIf::kConditionOperandOffset, condition);
-    AddOperands(BreakIf::kArgsOperandOffset, std::move(args));
+    AddOperands(BreakIf::kArgsOperandOffset, std::move(next_iter_values));
+    AddOperands(BreakIf::kArgsOperandOffset + num_next_iter_values_, std::move(exit_values));
 
     if (loop_) {
         loop_->Body()->AddInboundSiblingBranch(this);
diff --git a/src/tint/lang/core/ir/break_if.h b/src/tint/lang/core/ir/break_if.h
index 85508e2..090f22a 100644
--- a/src/tint/lang/core/ir/break_if.h
+++ b/src/tint/lang/core/ir/break_if.h
@@ -57,8 +57,14 @@
     /// Constructor
     /// @param condition the break condition
     /// @param loop the loop containing the break-if
-    /// @param args the MultiInBlock arguments
-    BreakIf(Value* condition, ir::Loop* loop, VectorRef<Value*> args = tint::Empty);
+    /// @param next_iter_values the arguments passed to the loop body MultiInBlock, if the break
+    /// condition evaluates to `false`.
+    /// @param exit_values the values returned by the loop, if the break condition evaluates to
+    /// `true`.
+    BreakIf(Value* condition,
+            ir::Loop* loop,
+            VectorRef<Value*> next_iter_values = tint::Empty,
+            VectorRef<Value*> exit_values = tint::Empty);
     ~BreakIf() override;
 
     /// @copydoc Instruction::Clone()
@@ -85,8 +91,39 @@
     /// @returns the friendly name for the instruction
     std::string FriendlyName() const override { return "break_if"; }
 
+    /// @returns the arguments passed to the loop body MultiInBlock, if the break condition
+    /// evaluates to `false`.
+    Slice<Value* const> NextIterValues() {
+        return operands_.Slice().Offset(kArgsOperandOffset).Truncate(num_next_iter_values_);
+    }
+
+    /// @returns the arguments passed to the loop body MultiInBlock, if the break condition
+    /// evaluates to `false`.
+    Slice<const Value* const> NextIterValues() const {
+        return operands_.Slice().Offset(kArgsOperandOffset).Truncate(num_next_iter_values_);
+    }
+
+    /// @returns the values returned by the loop, if the break condition evaluates to `true`.
+    Slice<Value* const> ExitValues() {
+        return operands_.Slice().Offset(kArgsOperandOffset + num_next_iter_values_);
+    }
+
+    /// @returns the values returned by the loop, if the break condition evaluates to `true`.
+    Slice<const Value* const> ExitValues() const {
+        return operands_.Slice().Offset(kArgsOperandOffset + num_next_iter_values_);
+    }
+
+    /// Sets the number of operands used as the next iterator values.
+    /// The first @p num operands after kArgsOperandOffset are used as next iterator values,
+    /// subsequent operators are used as exit values.
+    void SetNumNextIterValues(size_t num) {
+        TINT_ASSERT(operands_.Length() >= num + kArgsOperandOffset);
+        num_next_iter_values_ = num;
+    }
+
   private:
     ConstPropagatingPtr<ir::Loop> loop_;
+    size_t num_next_iter_values_ = 0;
 };
 
 }  // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index 84fc0ed..5fbbb14 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -1304,14 +1304,32 @@
     /// Creates a loop break-if instruction
     /// @param condition the break condition
     /// @param loop the loop being iterated
-    /// @param args the arguments for the target MultiInBlock
     /// @returns the instruction
-    template <typename CONDITION, typename... ARGS>
-    ir::BreakIf* BreakIf(ir::Loop* loop, CONDITION&& condition, ARGS&&... args) {
-        CheckForNonDeterministicEvaluation<CONDITION, ARGS...>();
+    template <typename CONDITION>
+    ir::BreakIf* BreakIf(ir::Loop* loop, CONDITION&& condition) {
+        CheckForNonDeterministicEvaluation<CONDITION>();
+        auto* cond_val = Value(std::forward<CONDITION>(condition));
+        return Append(ir.allocators.instructions.Create<ir::BreakIf>(cond_val, loop));
+    }
+
+    /// Creates a loop break-if instruction
+    /// @param condition the break condition
+    /// @param loop the loop being iterated
+    /// @param next_iter_values the arguments passed to the loop body MultiInBlock, if the break
+    /// condition evaluates to `false`.
+    /// @param exit_values the values returned by the loop, if the break condition evaluates to
+    /// `true`.
+    /// @returns the instruction
+    template <typename CONDITION, typename NEXT_ITER_VALUES, typename EXIT_VALUES>
+    ir::BreakIf* BreakIf(ir::Loop* loop,
+                         CONDITION&& condition,
+                         NEXT_ITER_VALUES&& next_iter_values,
+                         EXIT_VALUES&& exit_values) {
+        CheckForNonDeterministicEvaluation<CONDITION, NEXT_ITER_VALUES, EXIT_VALUES>();
         auto* cond_val = Value(std::forward<CONDITION>(condition));
         return Append(ir.allocators.instructions.Create<ir::BreakIf>(
-            cond_val, loop, Values(std::forward<ARGS>(args)...)));
+            cond_val, loop, Values(std::forward<NEXT_ITER_VALUES>(next_iter_values)),
+            Values(std::forward<EXIT_VALUES>(exit_values))));
     }
 
     /// Creates a continue instruction
diff --git a/src/tint/lang/core/ir/disassembly.cc b/src/tint/lang/core/ir/disassembly.cc
index df0cc4e..1a8a9a6 100644
--- a/src/tint/lang/core/ir/disassembly.cc
+++ b/src/tint/lang/core/ir/disassembly.cc
@@ -26,7 +26,10 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include "src/tint/lang/core/ir/disassembly.h"
+
+#include <algorithm>
 #include <memory>
+#include <optional>
 #include <string_view>
 
 #include "src//tint/lang/core/ir/unary.h"
@@ -577,6 +580,16 @@
     }
 }
 
+void Disassembly::EmitOperandList(const Instruction* inst, size_t start_index, size_t count) {
+    size_t n = std::min(count, inst->Operands().Length());
+    for (size_t i = start_index; i < n; i++) {
+        if (i != start_index) {
+            out_ << ", ";
+        }
+        EmitOperand(inst, i);
+    }
+}
+
 void Disassembly::EmitIf(const If* if_) {
     SourceMarker sm(this);
     if (auto results = if_->Results(); !results.IsEmpty()) {
@@ -730,52 +743,74 @@
     out_ << "}";
 }
 
-void Disassembly::EmitTerminator(const Terminator* b) {
+void Disassembly::EmitTerminator(const Terminator* term) {
     SourceMarker sm(this);
-    size_t args_offset = 0;
-    tint::Switch(
-        b,
+    auto args_offset = tint::Switch<std::optional<size_t>>(
+        term,
         [&](const ir::Return*) {
             out_ << StyleInstruction("ret");
-            args_offset = ir::Return::kArgsOperandOffset;
+            return ir::Return::kArgsOperandOffset;
         },
         [&](const ir::Continue*) {
             out_ << StyleInstruction("continue");
-            args_offset = ir::Continue::kArgsOperandOffset;
+            return ir::Continue::kArgsOperandOffset;
         },
         [&](const ir::ExitIf*) {
             out_ << StyleInstruction("exit_if");
-            args_offset = ir::ExitIf::kArgsOperandOffset;
+            return ir::ExitIf::kArgsOperandOffset;
         },
         [&](const ir::ExitSwitch*) {
             out_ << StyleInstruction("exit_switch");
-            args_offset = ir::ExitSwitch::kArgsOperandOffset;
+            return ir::ExitSwitch::kArgsOperandOffset;
         },
         [&](const ir::ExitLoop*) {
             out_ << StyleInstruction("exit_loop");
-            args_offset = ir::ExitLoop::kArgsOperandOffset;
+            return ir::ExitLoop::kArgsOperandOffset;
         },
         [&](const ir::NextIteration*) {
             out_ << StyleInstruction("next_iteration");
-            args_offset = ir::NextIteration::kArgsOperandOffset;
+            return ir::NextIteration::kArgsOperandOffset;
         },
-        [&](const ir::Unreachable*) { out_ << StyleInstruction("unreachable"); },
+        [&](const ir::Unreachable*) {
+            out_ << StyleInstruction("unreachable");
+            return std::nullopt;
+        },
         [&](const ir::BreakIf* bi) {
-            out_ << StyleInstruction("break_if") << " ";
+            out_ << StyleInstruction("break_if");
+            out_ << " ";
             EmitValue(bi->Condition());
-            args_offset = ir::BreakIf::kArgsOperandOffset;
+            auto next_iter_values = bi->NextIterValues();
+            auto exit_values = bi->ExitValues();
+            if (!next_iter_values.IsEmpty()) {
+                out_ << " " << StyleLabel("next_iteration") << ": [ ";
+                EmitOperandList(bi, ir::BreakIf::kArgsOperandOffset, next_iter_values.Length());
+                out_ << " ]";
+            }
+            if (!exit_values.IsEmpty()) {
+                out_ << " " << StyleLabel("exit_loop") << ": [ ";
+                EmitOperandList(bi, ir::BreakIf::kArgsOperandOffset + next_iter_values.Length());
+                out_ << " ]";
+            }
+            return std::nullopt;
         },
-        [&](const ir::TerminateInvocation*) { out_ << StyleInstruction("terminate_invocation"); },
-        [&](Default) { out_ << StyleError("unknown terminator ", b->TypeInfo().name); });
+        [&](const ir::TerminateInvocation*) {
+            out_ << StyleInstruction("terminate_invocation");
+            return std::nullopt;
+        },
+        [&](Default) {
+            out_ << StyleError("unknown terminator ", term->TypeInfo().name);
+            return std::nullopt;
+        });
 
-    if (!b->Args().IsEmpty()) {
+    if (args_offset && !term->Args().IsEmpty()) {
         out_ << " ";
-        EmitOperandList(b, args_offset);
+        EmitOperandList(term, *args_offset);
     }
-    sm.Store(b);
+
+    sm.Store(term);
 
     tint::Switch(
-        b,  //
+        term,  //
         [&](const ir::BreakIf* bi) {
             out_ << "  "
                  << StyleComment("# -> [t: exit_loop ", NameOf(bi->Loop()),
diff --git a/src/tint/lang/core/ir/disassembly.h b/src/tint/lang/core/ir/disassembly.h
index 70c48b4..99dfced 100644
--- a/src/tint/lang/core/ir/disassembly.h
+++ b/src/tint/lang/core/ir/disassembly.h
@@ -247,6 +247,7 @@
     void EmitLine();
     void EmitOperand(const Instruction* inst, size_t index);
     void EmitOperandList(const Instruction* inst, size_t start_index = 0);
+    void EmitOperandList(const Instruction* inst, size_t start_index, size_t count);
     void EmitInstructionName(const Instruction* inst);
 
     const Module& mod_;
diff --git a/src/tint/lang/spirv/writer/loop_test.cc b/src/tint/lang/spirv/writer/loop_test.cc
index 633eb9b..1fb4f9e 100644
--- a/src/tint/lang/spirv/writer/loop_test.cc
+++ b/src/tint/lang/spirv/writer/loop_test.cc
@@ -358,7 +358,7 @@
         loop->Continuing()->SetParams({cont_param});
         b.Append(loop->Continuing(), [&] {
             auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
-            b.BreakIf(loop, cmp, cont_param);
+            b.BreakIf(loop, cmp, /* next_iter */ Vector{cont_param}, /* exit */ Empty);
         });
 
         b.Return(func);
@@ -467,7 +467,7 @@
         loop->Continuing()->SetParams({cont_param});
         b.Append(loop->Continuing(), [&] {
             auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
-            b.BreakIf(loop, cmp, cont_param);
+            b.BreakIf(loop, cmp, /* next_iter */ Vector{cont_param}, /* exit */ Empty);
         });
 
         b.Return(func);
@@ -533,7 +533,7 @@
         outer->Continuing()->SetParams({cont_param});
         b.Append(outer->Continuing(), [&] {
             auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
-            b.BreakIf(outer, cmp, cont_param);
+            b.BreakIf(outer, cmp, /* next_iter */ Vector{cont_param}, /* exit */ Empty);
         });
 
         b.Return(func);
diff --git a/src/tint/lang/spirv/writer/raise/merge_return_test.cc b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
index 40d44b0..9cb2f0a 100644
--- a/src/tint/lang/spirv/writer/raise/merge_return_test.cc
+++ b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
@@ -1698,7 +1698,7 @@
 
         b.Append(loop->Continuing(), [&] {
             b.Store(global, 1_i);
-            b.BreakIf(loop, true, 4_i);
+            b.BreakIf(loop, true, /* next_iter */ b.Values(4_i), /* exit */ Empty);
         });
 
         b.Store(global, 3_i);