[tint][ir] Serialize Loop instructions
Change-Id: I189ea4feea103742363bb57409c1d32d4ef898de
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/164886
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index 33a00e9..2a790a6 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -54,6 +54,10 @@
Vector<ir::ExitIf*, 32> exit_ifs_{};
Vector<ir::ExitSwitch*, 32> exit_switches_{};
+ Vector<ir::ExitLoop*, 32> exit_loops_{};
+ Vector<ir::NextIteration*, 32> next_iterations_{};
+ Vector<ir::BreakIf*, 32> break_ifs_{};
+ Vector<ir::Continue*, 32> continues_{};
void Decode() {
{
@@ -110,6 +114,18 @@
for (auto* exit : exit_switches_) {
InferControlInstruction(exit, &ExitSwitch::SetSwitch);
}
+ for (auto* exit : exit_loops_) {
+ InferControlInstruction(exit, &ExitLoop::SetLoop);
+ }
+ for (auto* break_ifs : break_ifs_) {
+ InferControlInstruction(break_ifs, &BreakIf::SetLoop);
+ }
+ for (auto* next_iters : next_iterations_) {
+ InferControlInstruction(next_iters, &NextIteration::SetLoop);
+ }
+ for (auto* cont : continues_) {
+ InferControlInstruction(cont, &Continue::SetLoop);
+ }
}
template <typename EXIT, typename CTRL_INST>
@@ -174,9 +190,18 @@
////////////////////////////////////////////////////////////////////////////
// Blocks
////////////////////////////////////////////////////////////////////////////
- ir::Block* CreateBlock(const pb::Block&) { return b.Block(); }
+ ir::Block* CreateBlock(const pb::Block& block_in) {
+ return block_in.is_multi_in() ? b.MultiInBlock() : b.Block();
+ }
void PopulateBlock(ir::Block* block_out, const pb::Block& block_in) {
+ if (block_in.is_multi_in()) {
+ Vector<ir::BlockParam*, 8> params;
+ for (auto param : block_in.parameters()) {
+ params.Push(ValueAs<BlockParam>(param));
+ }
+ block_out->As<ir::MultiInBlock>()->SetParams(std::move(params));
+ }
for (auto& inst : block_in.instructions()) {
block_out->Append(Instruction(inst));
}
@@ -184,6 +209,17 @@
ir::Block* Block(uint32_t id) { return id > 0 ? blocks_[id - 1] : nullptr; }
+ template <typename T>
+ T* BlockAs(uint32_t id) {
+ auto* block = Block(id);
+ if (auto cast = block->As<T>(); TINT_LIKELY(cast)) {
+ return cast;
+ }
+ TINT_ICE() << "block " << id << " is " << (block ? block->TypeInfo().name : "<null>")
+ << " expected " << TypeInfo::Of<T>().name;
+ return nullptr;
+ }
+
////////////////////////////////////////////////////////////////////////////
// Instructions
////////////////////////////////////////////////////////////////////////////
@@ -196,18 +232,27 @@
case pb::Instruction::KindCase::kBinary:
inst_out = CreateInstructionBinary(inst_in.binary());
break;
+ case pb::Instruction::KindCase::kBreakIf:
+ inst_out = CreateInstructionBreakIf(inst_in.break_if());
+ break;
case pb::Instruction::KindCase::kBuiltinCall:
inst_out = CreateInstructionBuiltinCall(inst_in.builtin_call());
break;
case pb::Instruction::KindCase::kConstruct:
inst_out = CreateInstructionConstruct(inst_in.construct());
break;
+ case pb::Instruction::KindCase::kContinue:
+ inst_out = CreateInstructionContinue(inst_in.continue_());
+ break;
case pb::Instruction::KindCase::kConvert:
inst_out = CreateInstructionConvert(inst_in.convert());
break;
case pb::Instruction::KindCase::kExitIf:
inst_out = CreateInstructionExitIf(inst_in.exit_if());
break;
+ case pb::Instruction::KindCase::kExitLoop:
+ inst_out = CreateInstructionExitLoop(inst_in.exit_loop());
+ break;
case pb::Instruction::KindCase::kExitSwitch:
inst_out = CreateInstructionExitSwitch(inst_in.exit_switch());
break;
@@ -226,6 +271,12 @@
case pb::Instruction::KindCase::kLoadVectorElement:
inst_out = CreateInstructionLoadVectorElement(inst_in.load_vector_element());
break;
+ case pb::Instruction::KindCase::kLoop:
+ inst_out = CreateInstructionLoop(inst_in.loop());
+ break;
+ case pb::Instruction::KindCase::kNextIteration:
+ inst_out = CreateInstructionNextIteration(inst_in.next_iteration());
+ break;
case pb::Instruction::KindCase::kReturn:
inst_out = CreateInstructionReturn(inst_in.return_());
break;
@@ -281,6 +332,12 @@
return binary_out;
}
+ ir::BreakIf* CreateInstructionBreakIf(const pb::InstructionBreakIf&) {
+ auto* break_if_out = mod_out_.instructions.Create<ir::BreakIf>();
+ break_ifs_.Push(break_if_out);
+ return break_if_out;
+ }
+
ir::CoreBuiltinCall* CreateInstructionBuiltinCall(const pb::InstructionBuiltinCall& call_in) {
auto* call_out = mod_out_.instructions.Create<ir::CoreBuiltinCall>();
call_out->SetFunc(BuiltinFn(call_in.builtin()));
@@ -291,6 +348,12 @@
return mod_out_.instructions.Create<ir::Construct>();
}
+ ir::Continue* CreateInstructionContinue(const pb::InstructionContinue&) {
+ auto* continue_ = mod_out_.instructions.Create<ir::Continue>();
+ continues_.Push(continue_);
+ return continue_;
+ }
+
ir::Convert* CreateInstructionConvert(const pb::InstructionConvert&) {
return mod_out_.instructions.Create<ir::Convert>();
}
@@ -301,6 +364,12 @@
return exit_out;
}
+ ir::ExitLoop* CreateInstructionExitLoop(const pb::InstructionExitLoop&) {
+ auto* exit_out = mod_out_.instructions.Create<ir::ExitLoop>();
+ exit_loops_.Push(exit_out);
+ return exit_out;
+ }
+
ir::ExitSwitch* CreateInstructionExitSwitch(const pb::InstructionExitSwitch&) {
auto* exit_out = mod_out_.instructions.Create<ir::ExitSwitch>();
exit_switches_.Push(exit_out);
@@ -335,6 +404,28 @@
return mod_out_.instructions.Create<ir::LoadVectorElement>();
}
+ ir::Loop* CreateInstructionLoop(const pb::InstructionLoop& loop_in) {
+ auto* loop_out = mod_out_.instructions.Create<ir::Loop>();
+ if (loop_in.has_initalizer()) {
+ loop_out->SetInitializer(Block(loop_in.initalizer()));
+ } else {
+ loop_out->SetInitializer(mod_out_.blocks.Create());
+ }
+ loop_out->SetBody(BlockAs<ir::MultiInBlock>(loop_in.body()));
+ if (loop_in.has_continuing()) {
+ loop_out->SetContinuing(BlockAs<ir::MultiInBlock>(loop_in.continuing()));
+ } else {
+ loop_out->SetContinuing(mod_out_.blocks.Create<ir::MultiInBlock>());
+ }
+ return loop_out;
+ }
+
+ ir::NextIteration* CreateInstructionNextIteration(const pb::InstructionNextIteration&) {
+ auto* next_it_out = mod_out_.instructions.Create<ir::NextIteration>();
+ next_iterations_.Push(next_it_out);
+ return next_it_out;
+ }
+
ir::Return* CreateInstructionReturn(const pb::InstructionReturn&) {
return mod_out_.instructions.Create<ir::Return>();
}
@@ -496,9 +587,7 @@
InterpolationSampling(interpolation_in.sampling());
}
}
- if (attributes_in.has_invariant()) {
- attributes_out.invariant = attributes_in.invariant();
- }
+ attributes_out.invariant = attributes_in.invariant();
}
offset = RoundUp(align, offset);
auto* member_out = mod_out_.Types().Get<core::type::StructMember>(
@@ -551,6 +640,15 @@
}
break;
}
+ case pb::Value::KindCase::kBlockParameter: {
+ auto& param_in = value_in.block_parameter();
+ auto* type = Type(param_in.type());
+ value_out = b.BlockParam(type);
+ if (param_in.has_name()) {
+ mod_out_.SetName(value_out, param_in.name());
+ }
+ break;
+ }
case pb::Value::KindCase::kConstant: {
value_out = b.Constant(ConstantValue(value_in.constant()));
break;
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index c9ee44c..7656fd7 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -36,18 +36,24 @@
#include "src/tint/lang/core/constant/splat.h"
#include "src/tint/lang/core/ir/access.h"
#include "src/tint/lang/core/ir/binary.h"
+#include "src/tint/lang/core/ir/break_if.h"
#include "src/tint/lang/core/ir/construct.h"
+#include "src/tint/lang/core/ir/continue.h"
#include "src/tint/lang/core/ir/convert.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
#include "src/tint/lang/core/ir/discard.h"
#include "src/tint/lang/core/ir/exit_if.h"
+#include "src/tint/lang/core/ir/exit_loop.h"
#include "src/tint/lang/core/ir/exit_switch.h"
#include "src/tint/lang/core/ir/function_param.h"
#include "src/tint/lang/core/ir/if.h"
#include "src/tint/lang/core/ir/let.h"
#include "src/tint/lang/core/ir/load.h"
#include "src/tint/lang/core/ir/load_vector_element.h"
+#include "src/tint/lang/core/ir/loop.h"
#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/multi_in_block.h"
+#include "src/tint/lang/core/ir/next_iteration.h"
#include "src/tint/lang/core/ir/return.h"
#include "src/tint/lang/core/ir/store.h"
#include "src/tint/lang/core/ir/store_vector_element.h"
@@ -148,6 +154,12 @@
for (auto* inst : *block_in) {
Instruction(*block_out.add_instructions(), inst);
}
+ if (auto* mib = block_in->As<ir::MultiInBlock>()) {
+ block_out.set_is_multi_in(true);
+ for (auto* param : mib->Params()) {
+ block_out.add_parameters(Value(param));
+ }
+ }
return id;
});
}
@@ -160,13 +172,16 @@
inst_in, //
[&](const ir::Access* i) { InstructionAccess(*inst_out.mutable_access(), i); },
[&](const ir::Binary* i) { InstructionBinary(*inst_out.mutable_binary(), i); },
+ [&](const ir::BreakIf* i) { InstructionBreakIf(*inst_out.mutable_break_if(), i); },
[&](const ir::CoreBuiltinCall* i) {
InstructionBuiltinCall(*inst_out.mutable_builtin_call(), i);
},
[&](const ir::Construct* i) { InstructionConstruct(*inst_out.mutable_construct(), i); },
+ [&](const ir::Continue* i) { InstructionContinue(*inst_out.mutable_continue_(), i); },
[&](const ir::Convert* i) { InstructionConvert(*inst_out.mutable_convert(), i); },
[&](const ir::Discard* i) { InstructionDiscard(*inst_out.mutable_discard(), i); },
[&](const ir::ExitIf* i) { InstructionExitIf(*inst_out.mutable_exit_if(), i); },
+ [&](const ir::ExitLoop* i) { InstructionExitLoop(*inst_out.mutable_exit_loop(), i); },
[&](const ir::ExitSwitch* i) {
InstructionExitSwitch(*inst_out.mutable_exit_switch(), i);
},
@@ -176,6 +191,10 @@
[&](const ir::LoadVectorElement* i) {
InstructionLoadVectorElement(*inst_out.mutable_load_vector_element(), i);
},
+ [&](const ir::Loop* i) { InstructionLoop(*inst_out.mutable_loop(), i); },
+ [&](const ir::NextIteration* i) {
+ InstructionNextIteration(*inst_out.mutable_next_iteration(), i);
+ },
[&](const ir::Return* i) { InstructionReturn(*inst_out.mutable_return_(), i); },
[&](const ir::Store* i) { InstructionStore(*inst_out.mutable_store(), i); },
[&](const ir::StoreVectorElement* i) {
@@ -201,6 +220,8 @@
binary_out.set_op(BinaryOp(binary_in->Op()));
}
+ void InstructionBreakIf(pb::InstructionBreakIf&, const ir::BreakIf*) {}
+
void InstructionBuiltinCall(pb::InstructionBuiltinCall& call_out,
const ir::CoreBuiltinCall* call_in) {
call_out.set_builtin(BuiltinFn(call_in->Func()));
@@ -208,6 +229,8 @@
void InstructionConstruct(pb::InstructionConstruct&, const ir::Construct*) {}
+ void InstructionContinue(pb::InstructionContinue&, const ir::Continue*) {}
+
void InstructionConvert(pb::InstructionConvert&, const ir::Convert*) {}
void InstructionIf(pb::InstructionIf& if_out, const ir::If* if_in) {
@@ -223,6 +246,8 @@
void InstructionExitIf(pb::InstructionExitIf&, const ir::ExitIf*) {}
+ void InstructionExitLoop(pb::InstructionExitLoop&, const ir::ExitLoop*) {}
+
void InstructionExitSwitch(pb::InstructionExitSwitch&, const ir::ExitSwitch*) {}
void InstructionLet(pb::InstructionLet&, const ir::Let*) {}
@@ -232,6 +257,18 @@
void InstructionLoadVectorElement(pb::InstructionLoadVectorElement&,
const ir::LoadVectorElement*) {}
+ void InstructionLoop(pb::InstructionLoop& loop_out, const ir::Loop* loop_in) {
+ if (loop_in->HasInitializer()) {
+ loop_out.set_initalizer(Block(loop_in->Initializer()));
+ }
+ loop_out.set_body(Block(loop_in->Body()));
+ if (loop_in->HasContinuing()) {
+ loop_out.set_continuing(Block(loop_in->Continuing()));
+ }
+ }
+
+ void InstructionNextIteration(pb::InstructionNextIteration&, const ir::NextIteration*) {}
+
void InstructionReturn(pb::InstructionReturn&, const ir::Return*) {}
void InstructionStore(pb::InstructionStore&, const ir::Store*) {}
@@ -378,6 +415,8 @@
}
return values_.GetOrCreate(value_in, [&] {
auto& value_out = *mod_out_.add_values();
+ auto id = static_cast<uint32_t>(mod_out_.values().size());
+
tint::Switch(
value_in,
[&](const ir::InstructionResult* v) {
@@ -386,11 +425,14 @@
[&](const ir::FunctionParam* v) {
FunctionParameter(*value_out.mutable_function_parameter(), v);
},
+ [&](const ir::BlockParam* v) {
+ BlockParameter(*value_out.mutable_block_parameter(), v);
+ },
[&](const ir::Function* v) { value_out.set_function(Function(v)); },
[&](const ir::Constant* v) { value_out.set_constant(ConstantValue(v->Value())); },
TINT_ICE_ON_NO_MATCH);
- return static_cast<uint32_t>(mod_out_.values().size());
+ return id;
});
}
@@ -408,6 +450,13 @@
}
}
+ void BlockParameter(pb::BlockParameter& param_out, const ir::BlockParam* param_in) {
+ param_out.set_type(Type(param_in->Type()));
+ if (auto name = mod_in_.NameOf(param_in); name.IsValid()) {
+ param_out.set_name(name.Name());
+ }
+ }
+
////////////////////////////////////////////////////////////////////////////
// ConstantValues
////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/core/ir/binary/ir.proto b/src/tint/lang/core/ir/binary/ir.proto
index 70ce8ae..d0ad749 100644
--- a/src/tint/lang/core/ir/binary/ir.proto
+++ b/src/tint/lang/core/ir/binary/ir.proto
@@ -112,7 +112,8 @@
uint32 function = 1; // Module.functions
InstructionResult instruction_result = 2;
FunctionParameter function_parameter = 3;
- uint32 constant = 4; // Module.constant_values
+ BlockParameter block_parameter = 4;
+ uint32 constant = 5; // Module.constant_values
}
}
@@ -126,6 +127,11 @@
optional string name = 2;
}
+message BlockParameter {
+ uint32 type = 1; // Module.types
+ optional string name = 2;
+}
+
////////////////////////////////////////////////////////////////////////////////
// ConstantValues
////////////////////////////////////////////////////////////////////////////////
@@ -188,6 +194,7 @@
message Block {
repeated uint32 parameters = 1; // Module.values
repeated Instruction instructions = 2;
+ bool is_multi_in = 3;
}
////////////////////////////////////////////////////////////////////////////////
@@ -216,8 +223,13 @@
InstructionSwizzle swizzle = 19;
InstructionIf if = 20;
InstructionSwitch switch = 21;
- InstructionExitIf exit_if = 22;
- InstructionExitSwitch exit_switch = 23;
+ InstructionLoop loop = 22;
+ InstructionExitIf exit_if = 23;
+ InstructionExitSwitch exit_switch = 24;
+ InstructionExitLoop exit_loop = 25;
+ InstructionNextIteration next_iteration = 26;
+ InstructionContinue continue = 27;
+ InstructionBreakIf break_if = 28;
}
}
@@ -276,10 +288,18 @@
repeated SwitchCase cases = 1; // Module.blocks
}
+message InstructionLoop {
+ optional uint32 initalizer = 1; // Module.blocks
+ optional uint32 body = 2; // Module.blocks
+ optional uint32 continuing = 3; // Module.blocks
+}
+
message InstructionExitIf {}
message InstructionExitSwitch {}
+message InstructionExitLoop {}
+
message SwitchCase {
uint32 block = 1; // Module.blocks
repeated uint32 selectors = 2; // Module.constant_values
@@ -291,6 +311,12 @@
uint32 binding = 2;
}
+message InstructionNextIteration {}
+
+message InstructionContinue {}
+
+message InstructionBreakIf {}
+
////////////////////////////////////////////////////////////////////////////////
// Attributes
////////////////////////////////////////////////////////////////////////////////
@@ -300,7 +326,7 @@
optional uint32 color = 3;
optional BuiltinValue builtin = 4;
optional AttributesInterpolation interpolation = 5;
- optional bool invariant = 6;
+ bool invariant = 6;
}
message AttributesInterpolation {
diff --git a/src/tint/lang/core/ir/binary/roundtrip_test.cc b/src/tint/lang/core/ir/binary/roundtrip_test.cc
index 28d5e7c..e4a2a58 100644
--- a/src/tint/lang/core/ir/binary/roundtrip_test.cc
+++ b/src/tint/lang/core/ir/binary/roundtrip_test.cc
@@ -527,5 +527,73 @@
RUN_TEST();
}
+TEST_F(IRBinaryRoundtripTest, LoopBody) {
+ auto* fn = b.Function("Function", ty.i32());
+ b.Append(fn->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.Return(fn, 1_i); });
+ });
+ RUN_TEST();
+}
+
+TEST_F(IRBinaryRoundtripTest, LoopInitBody) {
+ auto* fn = b.Function("Function", ty.i32());
+ b.Append(fn->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ b.Let("L", 1_i);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] { b.Return(fn, 2_i); });
+ });
+ RUN_TEST();
+}
+
+TEST_F(IRBinaryRoundtripTest, LoopInitBodyCont) {
+ auto* fn = b.Function("Function", ty.i32());
+ b.Append(fn->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ b.Let("L", 1_i);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] { b.Continue(loop); });
+ b.Append(loop->Continuing(), [&] { b.BreakIf(loop, false); });
+ });
+ RUN_TEST();
+}
+
+TEST_F(IRBinaryRoundtripTest, LoopResults) {
+ auto* fn = b.Function("Function", ty.i32());
+ b.Append(fn->Block(), [&] {
+ auto* loop = b.Loop();
+ auto* res = b.InstructionResult<i32>();
+ loop->SetResults(Vector{res});
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop, 1_i); });
+ b.Return(fn, res);
+ });
+ RUN_TEST();
+}
+
+TEST_F(IRBinaryRoundtripTest, LoopBlockParams) {
+ auto* fn = b.Function("Function", ty.void_());
+ b.Append(fn->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ b.Let("L", 1_i);
+ b.NextIteration(loop);
+ });
+ auto* x = b.BlockParam<i32>("x");
+ auto* y = b.BlockParam<f32>("y");
+ loop->Body()->SetParams({x, y});
+ b.Append(loop->Body(), [&] { b.Continue(loop, 1_u, true); });
+ auto* z = b.BlockParam<u32>("z");
+ auto* w = b.BlockParam<bool>("w");
+ loop->Continuing()->SetParams({z, w});
+ b.Append(loop->Continuing(), [&] { b.BreakIf(loop, false, 3_i, 4_f); });
+ b.Return(fn);
+ });
+ RUN_TEST();
+}
} // namespace
} // namespace tint::core::ir::binary
diff --git a/src/tint/lang/core/ir/break_if.cc b/src/tint/lang/core/ir/break_if.cc
index 3afa0aa..dc312e6 100644
--- a/src/tint/lang/core/ir/break_if.cc
+++ b/src/tint/lang/core/ir/break_if.cc
@@ -40,6 +40,8 @@
namespace tint::core::ir {
+BreakIf::BreakIf() = default;
+
BreakIf::BreakIf(Value* condition, ir::Loop* loop, VectorRef<Value*> args) : loop_(loop) {
TINT_ASSERT(loop_);
@@ -60,4 +62,14 @@
return ctx.ir.instructions.Create<BreakIf>(cond, loop, args);
}
+void BreakIf::SetLoop(ir::Loop* loop) {
+ if (loop_ && loop_->Body()) {
+ loop_->Body()->RemoveInboundSiblingBranch(this);
+ }
+ loop_ = loop;
+ if (loop) {
+ loop->Body()->AddInboundSiblingBranch(this);
+ }
+}
+
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/break_if.h b/src/tint/lang/core/ir/break_if.h
index bc8dbdf..85508e2 100644
--- a/src/tint/lang/core/ir/break_if.h
+++ b/src/tint/lang/core/ir/break_if.h
@@ -51,6 +51,9 @@
/// The base offset in Operands() for the arguments
static constexpr size_t kArgsOperandOffset = 1;
+ /// Constructor (no operands, no loop)
+ BreakIf();
+
/// Constructor
/// @param condition the break condition
/// @param loop the loop containing the break-if
@@ -76,6 +79,9 @@
/// @returns the loop containing the break-if
const ir::Loop* Loop() const { return loop_; }
+ /// @param loop the new loop containing the continue
+ void SetLoop(ir::Loop* loop);
+
/// @returns the friendly name for the instruction
std::string FriendlyName() const override { return "break_if"; }
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index ed5ed77..ebbd851 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -1257,6 +1257,16 @@
/// @returns the value
ir::BlockParam* BlockParam(std::string_view name, const core::type::Type* type);
+ /// Creates a new `BlockParam` with a name.
+ /// @tparam TYPE the parameter type
+ /// @param name the parameter name
+ /// @returns the value
+ template <typename TYPE>
+ ir::BlockParam* BlockParam(std::string_view name) {
+ auto* type = ir.Types().Get<TYPE>();
+ return BlockParam(name, type);
+ }
+
/// Creates a new `FunctionParam`
/// @param type the parameter type
/// @returns the value
diff --git a/src/tint/lang/core/ir/continue.cc b/src/tint/lang/core/ir/continue.cc
index 4eebd7a..66c9115 100644
--- a/src/tint/lang/core/ir/continue.cc
+++ b/src/tint/lang/core/ir/continue.cc
@@ -40,6 +40,8 @@
namespace tint::core::ir {
+Continue::Continue() = default;
+
Continue::Continue(ir::Loop* loop, VectorRef<Value*> args) : loop_(loop) {
TINT_ASSERT(loop_);
@@ -59,4 +61,14 @@
return ctx.ir.instructions.Create<Continue>(loop, args);
}
+void Continue::SetLoop(ir::Loop* loop) {
+ if (loop_ && loop_->Body()) {
+ loop_->Body()->RemoveInboundSiblingBranch(this);
+ }
+ loop_ = loop;
+ if (loop) {
+ loop->Body()->AddInboundSiblingBranch(this);
+ }
+}
+
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/continue.h b/src/tint/lang/core/ir/continue.h
index 6233282..803ae8f 100644
--- a/src/tint/lang/core/ir/continue.h
+++ b/src/tint/lang/core/ir/continue.h
@@ -47,6 +47,9 @@
/// The base offset in Operands() for the args
static constexpr size_t kArgsOperandOffset = 0;
+ /// Constructor (no operands, no loop)
+ Continue();
+
/// Constructor
/// @param loop the loop owning the continue block
/// @param args the arguments for the MultiInBlock
@@ -62,6 +65,9 @@
/// @returns the loop owning the continue block
const ir::Loop* Loop() const { return loop_; }
+ /// @param loop the new loop owning the continue block
+ void SetLoop(ir::Loop* loop);
+
/// @returns the friendly name for the instruction
std::string FriendlyName() const override { return "continue"; }
diff --git a/src/tint/lang/core/ir/exit_loop.cc b/src/tint/lang/core/ir/exit_loop.cc
index fe8da1b..31289b5 100644
--- a/src/tint/lang/core/ir/exit_loop.cc
+++ b/src/tint/lang/core/ir/exit_loop.cc
@@ -39,6 +39,8 @@
namespace tint::core::ir {
+ExitLoop::ExitLoop() = default;
+
ExitLoop::ExitLoop(ir::Loop* loop, VectorRef<Value*> args /* = tint::Empty */) {
SetLoop(loop);
AddOperands(ExitLoop::kArgsOperandOffset, std::move(args));
diff --git a/src/tint/lang/core/ir/exit_loop.h b/src/tint/lang/core/ir/exit_loop.h
index 0686fc2..d6cde81 100644
--- a/src/tint/lang/core/ir/exit_loop.h
+++ b/src/tint/lang/core/ir/exit_loop.h
@@ -46,6 +46,9 @@
/// The base offset in Operands() for the args
static constexpr size_t kArgsOperandOffset = 0;
+ /// Constructor (no operands, no loop)
+ ExitLoop();
+
/// Constructor
/// @param loop the loop being exited
/// @param args the target MultiInBlock arguments
diff --git a/src/tint/lang/core/ir/loop.cc b/src/tint/lang/core/ir/loop.cc
index a5e3519..64a5a6c 100644
--- a/src/tint/lang/core/ir/loop.cc
+++ b/src/tint/lang/core/ir/loop.cc
@@ -36,6 +36,8 @@
namespace tint::core::ir {
+Loop::Loop() = default;
+
Loop::Loop(ir::Block* i, ir::MultiInBlock* b, ir::MultiInBlock* c)
: initializer_(i), body_(b), continuing_(c) {
TINT_ASSERT(initializer_);
@@ -84,8 +86,42 @@
}
}
-bool Loop::HasInitializer() {
+bool Loop::HasInitializer() const {
return initializer_->Terminator() != nullptr;
}
+void Loop::SetInitializer(ir::Block* block) {
+ if (initializer_ && initializer_->Parent() == this) {
+ initializer_->SetParent(nullptr);
+ }
+ initializer_ = block;
+ if (block) {
+ block->SetParent(this);
+ }
+}
+
+void Loop::SetBody(ir::MultiInBlock* block) {
+ if (body_ && body_->Parent() == this) {
+ body_->SetParent(nullptr);
+ }
+ body_ = block;
+ if (block) {
+ block->SetParent(this);
+ }
+}
+
+bool Loop::HasContinuing() const {
+ return continuing_->Terminator() != nullptr;
+}
+
+void Loop::SetContinuing(ir::MultiInBlock* block) {
+ if (continuing_ && continuing_->Parent() == this) {
+ continuing_->SetParent(nullptr);
+ }
+ continuing_ = block;
+ if (block) {
+ block->SetParent(this);
+ }
+}
+
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/loop.h b/src/tint/lang/core/ir/loop.h
index c22b8b97..c87760e 100644
--- a/src/tint/lang/core/ir/loop.h
+++ b/src/tint/lang/core/ir/loop.h
@@ -72,6 +72,9 @@
/// ```
class Loop final : public Castable<Loop, ControlInstruction> {
public:
+ /// Constructor (no results, no operands, no blocks)
+ Loop();
+
/// Constructor
/// @param i the initializer block
/// @param b the body block
@@ -93,7 +96,10 @@
/// @returns true if the loop uses an initializer block. If true, then the Loop first branches
/// to the initializer block, otherwise it first branches to the body block.
- bool HasInitializer();
+ bool HasInitializer() const;
+
+ /// @param block the new switch initializer block
+ void SetInitializer(ir::Block* block);
/// @returns the switch start block
ir::MultiInBlock* Body() { return body_; }
@@ -101,12 +107,22 @@
/// @returns the switch start block
const ir::MultiInBlock* Body() const { return body_; }
+ /// @param block the new switch body block
+ void SetBody(ir::MultiInBlock* block);
+
/// @returns the switch continuing block
ir::MultiInBlock* Continuing() { return continuing_; }
/// @returns the switch continuing block
const ir::MultiInBlock* Continuing() const { return continuing_; }
+ /// @returns true if the loop uses an continuing block. If true, then the Loop first branches
+ /// to the continuing block, otherwise it first branches to the body block.
+ bool HasContinuing() const;
+
+ /// @param block the new switch continuing block
+ void SetContinuing(ir::MultiInBlock* block);
+
/// @returns the friendly name for the instruction
std::string FriendlyName() const override { return "loop"; }
diff --git a/src/tint/lang/core/ir/multi_in_block.cc b/src/tint/lang/core/ir/multi_in_block.cc
index f10f741..c57c622 100644
--- a/src/tint/lang/core/ir/multi_in_block.cc
+++ b/src/tint/lang/core/ir/multi_in_block.cc
@@ -63,11 +63,13 @@
}
void MultiInBlock::AddInboundSiblingBranch(ir::Terminator* node) {
- TINT_ASSERT(node != nullptr);
+ TINT_ASSERT_OR_RETURN(node != nullptr);
+ inbound_sibling_branches_.Push(node);
+}
- if (node) {
- inbound_sibling_branches_.Push(node);
- }
+void MultiInBlock::RemoveInboundSiblingBranch(ir::Terminator* node) {
+ TINT_ASSERT_OR_RETURN(node != nullptr);
+ inbound_sibling_branches_.EraseIf([node](ir::Terminator* i) { return i == node; });
}
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/multi_in_block.h b/src/tint/lang/core/ir/multi_in_block.h
index 76915f1..443e716 100644
--- a/src/tint/lang/core/ir/multi_in_block.h
+++ b/src/tint/lang/core/ir/multi_in_block.h
@@ -71,6 +71,10 @@
/// @param branch the branch to add
void AddInboundSiblingBranch(ir::Terminator* branch);
+ /// Removes the given branch to the list of branches made to this block by sibling blocks
+ /// @param branch the branch to remove
+ void RemoveInboundSiblingBranch(ir::Terminator* branch);
+
private:
Vector<BlockParam*, 2> params_;
Vector<ir::Terminator*, 2> inbound_sibling_branches_;
diff --git a/src/tint/lang/core/ir/next_iteration.cc b/src/tint/lang/core/ir/next_iteration.cc
index 63291b6..5de66c7 100644
--- a/src/tint/lang/core/ir/next_iteration.cc
+++ b/src/tint/lang/core/ir/next_iteration.cc
@@ -39,6 +39,8 @@
namespace tint::core::ir {
+NextIteration::NextIteration() = default;
+
NextIteration::NextIteration(ir::Loop* loop, VectorRef<Value*> args /* = tint::Empty */)
: loop_(loop) {
TINT_ASSERT(loop_);
@@ -58,4 +60,14 @@
return ctx.ir.instructions.Create<NextIteration>(new_loop, args);
}
+void NextIteration::SetLoop(ir::Loop* loop) {
+ if (loop_ && loop_->Body()) {
+ loop_->Body()->RemoveInboundSiblingBranch(this);
+ }
+ loop_ = loop;
+ if (loop) {
+ loop->Body()->AddInboundSiblingBranch(this);
+ }
+}
+
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/next_iteration.h b/src/tint/lang/core/ir/next_iteration.h
index f859ee0..c80c10f 100644
--- a/src/tint/lang/core/ir/next_iteration.h
+++ b/src/tint/lang/core/ir/next_iteration.h
@@ -47,6 +47,9 @@
/// The base offset in Operands() for the args
static constexpr size_t kArgsOperandOffset = 0;
+ /// Constructor (no operands, no loop)
+ NextIteration();
+
/// Constructor
/// @param loop the loop being iterated
/// @param args the arguments for the MultiInBlock
@@ -62,6 +65,9 @@
/// @returns the loop being iterated
const ir::Loop* Loop() const { return loop_; }
+ /// @param loop the new loop being iterated
+ void SetLoop(ir::Loop* loop);
+
/// @returns the friendly name for the instruction
std::string FriendlyName() const override { return "next_iteration"; }