[tint][ir] Split BreakIf arguments into two lists
One for the next-iteration.
One for the loop exit.
Validation will be done as a followup CL.
Change-Id: Iee36c9043b70e9b7ffbc92ef1f28e77d9de996f9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/187686
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index eb320b2..c422433 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -341,6 +341,11 @@
}
inst_out->SetResults(std::move(results));
+ if (inst_in.has_break_if()) {
+ static_cast<BreakIf*>(inst_out)->SetNumNextIterValues(
+ inst_in.break_if().num_next_iter_values());
+ }
+
return inst_out;
}
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index 70305bd..74d00ac 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -253,7 +253,10 @@
void InstructionBitcast(pb::InstructionBitcast&, const ir::Bitcast*) {}
- void InstructionBreakIf(pb::InstructionBreakIf&, const ir::BreakIf*) {}
+ void InstructionBreakIf(pb::InstructionBreakIf& breakif_out, const ir::BreakIf* breakif_in) {
+ auto num_next_iter_values = static_cast<uint32_t>(breakif_in->NextIterValues().Length());
+ breakif_out.set_num_next_iter_values(num_next_iter_values);
+ }
void InstructionBuiltinCall(pb::InstructionBuiltinCall& call_out,
const ir::CoreBuiltinCall* call_in) {
diff --git a/src/tint/lang/core/ir/binary/ir.proto b/src/tint/lang/core/ir/binary/ir.proto
index 19f3b0a..31c2821 100644
--- a/src/tint/lang/core/ir/binary/ir.proto
+++ b/src/tint/lang/core/ir/binary/ir.proto
@@ -354,7 +354,9 @@
message InstructionContinue {}
-message InstructionBreakIf {}
+message InstructionBreakIf {
+ uint32 num_next_iter_values = 1;
+}
message InstructionUnreachable {}
diff --git a/src/tint/lang/core/ir/binary/roundtrip_test.cc b/src/tint/lang/core/ir/binary/roundtrip_test.cc
index 52a516b..76f92b8 100644
--- a/src/tint/lang/core/ir/binary/roundtrip_test.cc
+++ b/src/tint/lang/core/ir/binary/roundtrip_test.cc
@@ -698,7 +698,10 @@
TEST_F(IRBinaryRoundtripTest, LoopBlockParams) {
auto* fn = b.Function("Function", ty.void_());
b.Append(fn->Block(), [&] {
+ auto* loop_res_a = b.InstructionResult(ty.i32());
+ auto* loop_res_b = b.InstructionResult(ty.f32());
auto* loop = b.Loop();
+ loop->SetResults(Vector{loop_res_a, loop_res_b});
b.Append(loop->Initializer(), [&] {
b.Let("L", 1_i);
b.NextIteration(loop);
@@ -710,7 +713,12 @@
auto* z = b.BlockParam<u32>("z");
auto* w = b.BlockParam<bool>("w");
loop->Continuing()->SetParams({z, w});
- b.Append(loop->Continuing(), [&] { b.BreakIf(loop, false, 3_i, 4_f); });
+ b.Append(loop->Continuing(), [&] {
+ b.BreakIf(loop,
+ /* condition */ false,
+ /* next iter */ b.Values(3_i, 4_f),
+ /* exit */ b.Values(5_u, 6_i));
+ });
b.Return(fn);
});
RUN_TEST();
diff --git a/src/tint/lang/core/ir/break_if.cc b/src/tint/lang/core/ir/break_if.cc
index 53a7145..efb2990 100644
--- a/src/tint/lang/core/ir/break_if.cc
+++ b/src/tint/lang/core/ir/break_if.cc
@@ -42,11 +42,16 @@
BreakIf::BreakIf() = default;
-BreakIf::BreakIf(Value* condition, ir::Loop* loop, VectorRef<Value*> args) : loop_(loop) {
+BreakIf::BreakIf(Value* condition,
+ ir::Loop* loop,
+ VectorRef<Value*> next_iter_values /* = tint::Empty */,
+ VectorRef<Value*> exit_values /* = tint::Empty */)
+ : loop_(loop), num_next_iter_values_(next_iter_values.Length()) {
TINT_ASSERT(loop_);
AddOperand(BreakIf::kConditionOperandOffset, condition);
- AddOperands(BreakIf::kArgsOperandOffset, std::move(args));
+ AddOperands(BreakIf::kArgsOperandOffset, std::move(next_iter_values));
+ AddOperands(BreakIf::kArgsOperandOffset + num_next_iter_values_, std::move(exit_values));
if (loop_) {
loop_->Body()->AddInboundSiblingBranch(this);
diff --git a/src/tint/lang/core/ir/break_if.h b/src/tint/lang/core/ir/break_if.h
index 85508e2..090f22a 100644
--- a/src/tint/lang/core/ir/break_if.h
+++ b/src/tint/lang/core/ir/break_if.h
@@ -57,8 +57,14 @@
/// Constructor
/// @param condition the break condition
/// @param loop the loop containing the break-if
- /// @param args the MultiInBlock arguments
- BreakIf(Value* condition, ir::Loop* loop, VectorRef<Value*> args = tint::Empty);
+ /// @param next_iter_values the arguments passed to the loop body MultiInBlock, if the break
+ /// condition evaluates to `false`.
+ /// @param exit_values the values returned by the loop, if the break condition evaluates to
+ /// `true`.
+ BreakIf(Value* condition,
+ ir::Loop* loop,
+ VectorRef<Value*> next_iter_values = tint::Empty,
+ VectorRef<Value*> exit_values = tint::Empty);
~BreakIf() override;
/// @copydoc Instruction::Clone()
@@ -85,8 +91,39 @@
/// @returns the friendly name for the instruction
std::string FriendlyName() const override { return "break_if"; }
+ /// @returns the arguments passed to the loop body MultiInBlock, if the break condition
+ /// evaluates to `false`.
+ Slice<Value* const> NextIterValues() {
+ return operands_.Slice().Offset(kArgsOperandOffset).Truncate(num_next_iter_values_);
+ }
+
+ /// @returns the arguments passed to the loop body MultiInBlock, if the break condition
+ /// evaluates to `false`.
+ Slice<const Value* const> NextIterValues() const {
+ return operands_.Slice().Offset(kArgsOperandOffset).Truncate(num_next_iter_values_);
+ }
+
+ /// @returns the values returned by the loop, if the break condition evaluates to `true`.
+ Slice<Value* const> ExitValues() {
+ return operands_.Slice().Offset(kArgsOperandOffset + num_next_iter_values_);
+ }
+
+ /// @returns the values returned by the loop, if the break condition evaluates to `true`.
+ Slice<const Value* const> ExitValues() const {
+ return operands_.Slice().Offset(kArgsOperandOffset + num_next_iter_values_);
+ }
+
+ /// Sets the number of operands used as the next iterator values.
+ /// The first @p num operands after kArgsOperandOffset are used as next iterator values,
+ /// subsequent operators are used as exit values.
+ void SetNumNextIterValues(size_t num) {
+ TINT_ASSERT(operands_.Length() >= num + kArgsOperandOffset);
+ num_next_iter_values_ = num;
+ }
+
private:
ConstPropagatingPtr<ir::Loop> loop_;
+ size_t num_next_iter_values_ = 0;
};
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index 84fc0ed..5fbbb14 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -1304,14 +1304,32 @@
/// Creates a loop break-if instruction
/// @param condition the break condition
/// @param loop the loop being iterated
- /// @param args the arguments for the target MultiInBlock
/// @returns the instruction
- template <typename CONDITION, typename... ARGS>
- ir::BreakIf* BreakIf(ir::Loop* loop, CONDITION&& condition, ARGS&&... args) {
- CheckForNonDeterministicEvaluation<CONDITION, ARGS...>();
+ template <typename CONDITION>
+ ir::BreakIf* BreakIf(ir::Loop* loop, CONDITION&& condition) {
+ CheckForNonDeterministicEvaluation<CONDITION>();
+ auto* cond_val = Value(std::forward<CONDITION>(condition));
+ return Append(ir.allocators.instructions.Create<ir::BreakIf>(cond_val, loop));
+ }
+
+ /// Creates a loop break-if instruction
+ /// @param condition the break condition
+ /// @param loop the loop being iterated
+ /// @param next_iter_values the arguments passed to the loop body MultiInBlock, if the break
+ /// condition evaluates to `false`.
+ /// @param exit_values the values returned by the loop, if the break condition evaluates to
+ /// `true`.
+ /// @returns the instruction
+ template <typename CONDITION, typename NEXT_ITER_VALUES, typename EXIT_VALUES>
+ ir::BreakIf* BreakIf(ir::Loop* loop,
+ CONDITION&& condition,
+ NEXT_ITER_VALUES&& next_iter_values,
+ EXIT_VALUES&& exit_values) {
+ CheckForNonDeterministicEvaluation<CONDITION, NEXT_ITER_VALUES, EXIT_VALUES>();
auto* cond_val = Value(std::forward<CONDITION>(condition));
return Append(ir.allocators.instructions.Create<ir::BreakIf>(
- cond_val, loop, Values(std::forward<ARGS>(args)...)));
+ cond_val, loop, Values(std::forward<NEXT_ITER_VALUES>(next_iter_values)),
+ Values(std::forward<EXIT_VALUES>(exit_values))));
}
/// Creates a continue instruction
diff --git a/src/tint/lang/core/ir/disassembly.cc b/src/tint/lang/core/ir/disassembly.cc
index df0cc4e..1a8a9a6 100644
--- a/src/tint/lang/core/ir/disassembly.cc
+++ b/src/tint/lang/core/ir/disassembly.cc
@@ -26,7 +26,10 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/core/ir/disassembly.h"
+
+#include <algorithm>
#include <memory>
+#include <optional>
#include <string_view>
#include "src//tint/lang/core/ir/unary.h"
@@ -577,6 +580,16 @@
}
}
+void Disassembly::EmitOperandList(const Instruction* inst, size_t start_index, size_t count) {
+ size_t n = std::min(count, inst->Operands().Length());
+ for (size_t i = start_index; i < n; i++) {
+ if (i != start_index) {
+ out_ << ", ";
+ }
+ EmitOperand(inst, i);
+ }
+}
+
void Disassembly::EmitIf(const If* if_) {
SourceMarker sm(this);
if (auto results = if_->Results(); !results.IsEmpty()) {
@@ -730,52 +743,74 @@
out_ << "}";
}
-void Disassembly::EmitTerminator(const Terminator* b) {
+void Disassembly::EmitTerminator(const Terminator* term) {
SourceMarker sm(this);
- size_t args_offset = 0;
- tint::Switch(
- b,
+ auto args_offset = tint::Switch<std::optional<size_t>>(
+ term,
[&](const ir::Return*) {
out_ << StyleInstruction("ret");
- args_offset = ir::Return::kArgsOperandOffset;
+ return ir::Return::kArgsOperandOffset;
},
[&](const ir::Continue*) {
out_ << StyleInstruction("continue");
- args_offset = ir::Continue::kArgsOperandOffset;
+ return ir::Continue::kArgsOperandOffset;
},
[&](const ir::ExitIf*) {
out_ << StyleInstruction("exit_if");
- args_offset = ir::ExitIf::kArgsOperandOffset;
+ return ir::ExitIf::kArgsOperandOffset;
},
[&](const ir::ExitSwitch*) {
out_ << StyleInstruction("exit_switch");
- args_offset = ir::ExitSwitch::kArgsOperandOffset;
+ return ir::ExitSwitch::kArgsOperandOffset;
},
[&](const ir::ExitLoop*) {
out_ << StyleInstruction("exit_loop");
- args_offset = ir::ExitLoop::kArgsOperandOffset;
+ return ir::ExitLoop::kArgsOperandOffset;
},
[&](const ir::NextIteration*) {
out_ << StyleInstruction("next_iteration");
- args_offset = ir::NextIteration::kArgsOperandOffset;
+ return ir::NextIteration::kArgsOperandOffset;
},
- [&](const ir::Unreachable*) { out_ << StyleInstruction("unreachable"); },
+ [&](const ir::Unreachable*) {
+ out_ << StyleInstruction("unreachable");
+ return std::nullopt;
+ },
[&](const ir::BreakIf* bi) {
- out_ << StyleInstruction("break_if") << " ";
+ out_ << StyleInstruction("break_if");
+ out_ << " ";
EmitValue(bi->Condition());
- args_offset = ir::BreakIf::kArgsOperandOffset;
+ auto next_iter_values = bi->NextIterValues();
+ auto exit_values = bi->ExitValues();
+ if (!next_iter_values.IsEmpty()) {
+ out_ << " " << StyleLabel("next_iteration") << ": [ ";
+ EmitOperandList(bi, ir::BreakIf::kArgsOperandOffset, next_iter_values.Length());
+ out_ << " ]";
+ }
+ if (!exit_values.IsEmpty()) {
+ out_ << " " << StyleLabel("exit_loop") << ": [ ";
+ EmitOperandList(bi, ir::BreakIf::kArgsOperandOffset + next_iter_values.Length());
+ out_ << " ]";
+ }
+ return std::nullopt;
},
- [&](const ir::TerminateInvocation*) { out_ << StyleInstruction("terminate_invocation"); },
- [&](Default) { out_ << StyleError("unknown terminator ", b->TypeInfo().name); });
+ [&](const ir::TerminateInvocation*) {
+ out_ << StyleInstruction("terminate_invocation");
+ return std::nullopt;
+ },
+ [&](Default) {
+ out_ << StyleError("unknown terminator ", term->TypeInfo().name);
+ return std::nullopt;
+ });
- if (!b->Args().IsEmpty()) {
+ if (args_offset && !term->Args().IsEmpty()) {
out_ << " ";
- EmitOperandList(b, args_offset);
+ EmitOperandList(term, *args_offset);
}
- sm.Store(b);
+
+ sm.Store(term);
tint::Switch(
- b, //
+ term, //
[&](const ir::BreakIf* bi) {
out_ << " "
<< StyleComment("# -> [t: exit_loop ", NameOf(bi->Loop()),
diff --git a/src/tint/lang/core/ir/disassembly.h b/src/tint/lang/core/ir/disassembly.h
index 70c48b4..99dfced 100644
--- a/src/tint/lang/core/ir/disassembly.h
+++ b/src/tint/lang/core/ir/disassembly.h
@@ -247,6 +247,7 @@
void EmitLine();
void EmitOperand(const Instruction* inst, size_t index);
void EmitOperandList(const Instruction* inst, size_t start_index = 0);
+ void EmitOperandList(const Instruction* inst, size_t start_index, size_t count);
void EmitInstructionName(const Instruction* inst);
const Module& mod_;
diff --git a/src/tint/lang/spirv/writer/loop_test.cc b/src/tint/lang/spirv/writer/loop_test.cc
index 633eb9b..1fb4f9e 100644
--- a/src/tint/lang/spirv/writer/loop_test.cc
+++ b/src/tint/lang/spirv/writer/loop_test.cc
@@ -358,7 +358,7 @@
loop->Continuing()->SetParams({cont_param});
b.Append(loop->Continuing(), [&] {
auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
- b.BreakIf(loop, cmp, cont_param);
+ b.BreakIf(loop, cmp, /* next_iter */ Vector{cont_param}, /* exit */ Empty);
});
b.Return(func);
@@ -467,7 +467,7 @@
loop->Continuing()->SetParams({cont_param});
b.Append(loop->Continuing(), [&] {
auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
- b.BreakIf(loop, cmp, cont_param);
+ b.BreakIf(loop, cmp, /* next_iter */ Vector{cont_param}, /* exit */ Empty);
});
b.Return(func);
@@ -533,7 +533,7 @@
outer->Continuing()->SetParams({cont_param});
b.Append(outer->Continuing(), [&] {
auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
- b.BreakIf(outer, cmp, cont_param);
+ b.BreakIf(outer, cmp, /* next_iter */ Vector{cont_param}, /* exit */ Empty);
});
b.Return(func);
diff --git a/src/tint/lang/spirv/writer/raise/merge_return_test.cc b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
index 40d44b0..9cb2f0a 100644
--- a/src/tint/lang/spirv/writer/raise/merge_return_test.cc
+++ b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
@@ -1698,7 +1698,7 @@
b.Append(loop->Continuing(), [&] {
b.Store(global, 1_i);
- b.BreakIf(loop, true, 4_i);
+ b.BreakIf(loop, true, /* next_iter */ b.Values(4_i), /* exit */ Empty);
});
b.Store(global, 3_i);