[ir][spirv-writer] Add MergeReturn transform
Each return instruction is replaced with instructions that set a flag
to indicate the function is returning, capture the return value, and
exit from the current block to the nearest merge block. The code in
the merge block is then conditionalized based on that flag. A single
return instructions is then added to the final merge block of the
function.
This is needed to ensure that we adhere to SPIR-V's uniformity
requirements.
Use it in the SPIR-V writer.
Bug: tint:1906
Change-Id: I80edef751acc8e807a78e55549244a61726b462e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/137340
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index b4161d2..0ec0f1a 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -521,6 +521,8 @@
"ir/transform/add_empty_entry_point.h",
"ir/transform/block_decorated_structs.cc",
"ir/transform/block_decorated_structs.h",
+ "ir/transform/merge_return.cc",
+ "ir/transform/merge_return.h",
"ir/transform/var_for_dynamic_index.cc",
"ir/transform/var_for_dynamic_index.h",
]
@@ -1830,6 +1832,7 @@
sources = [
"ir/transform/add_empty_entry_point_test.cc",
"ir/transform/block_decorated_structs_test.cc",
+ "ir/transform/merge_return_test.cc",
"ir/transform/test_helper.h",
"ir/transform/var_for_dynamic_index_test.cc",
]
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 6728f08..d9f60b4 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -811,6 +811,8 @@
ir/transform/add_empty_entry_point.h
ir/transform/block_decorated_structs.cc
ir/transform/block_decorated_structs.h
+ ir/transform/merge_return.cc
+ ir/transform/merge_return.h
ir/transform/transform.cc
ir/transform/transform.h
ir/transform/var_for_dynamic_index.cc
@@ -1550,6 +1552,7 @@
ir/swizzle_test.cc
ir/transform/add_empty_entry_point_test.cc
ir/transform/block_decorated_structs_test.cc
+ ir/transform/merge_return_test.cc
ir/transform/var_for_dynamic_index_test.cc
ir/unary_test.cc
ir/user_call_test.cc
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 98795b7..dcae264 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -192,6 +192,10 @@
return Constant(std::forward<T>(number));
}
+ /// Pass-through overload for nullptr values
+ /// @returns nullptr
+ ir::Value* Value(std::nullptr_t) { return nullptr; }
+
/// Pass-through overload for Value()
/// @param v the ir::Value pointer
/// @returns @p v
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index b7e3bc5..8ba344f 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -344,7 +344,13 @@
[&](ir::InstructionResult* rv) { out_ << "%" << IdOf(rv); },
[&](ir::BlockParam* p) { out_ << "%" << IdOf(p) << ":" << p->Type()->FriendlyName(); },
[&](ir::FunctionParam* p) { out_ << "%" << IdOf(p); },
- [&](Default) { out_ << "Unknown value: " << val->TypeInfo().name; });
+ [&](Default) {
+ if (val == nullptr) {
+ out_ << "undef";
+ } else {
+ out_ << "Unknown value: " << val->TypeInfo().name;
+ }
+ });
}
void Disassembler::EmitInstructionName(std::string_view name, Instruction* inst) {
diff --git a/src/tint/ir/operand_instruction.h b/src/tint/ir/operand_instruction.h
index 43a43d1..999a883 100644
--- a/src/tint/ir/operand_instruction.h
+++ b/src/tint/ir/operand_instruction.h
@@ -74,8 +74,8 @@
operands_.Push(value);
}
- /// Append a list of non-null operands to the operand list for this instruction.
- /// @param start_idx the index from whic the values should start
+ /// Append a list of operands to the operand list for this instruction.
+ /// @param start_idx the index from which the values should start
/// @param values the operand values to append
void AddOperands(size_t start_idx, utils::VectorRef<ir::Value*> values) {
size_t idx = start_idx;
diff --git a/src/tint/ir/transform/merge_return.cc b/src/tint/ir/transform/merge_return.cc
new file mode 100644
index 0000000..901a411
--- /dev/null
+++ b/src/tint/ir/transform/merge_return.cc
@@ -0,0 +1,292 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/merge_return.h"
+
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+#include "src/tint/switch.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::MergeReturn);
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ir::transform {
+
+MergeReturn::MergeReturn() = default;
+
+MergeReturn::~MergeReturn() = default;
+
+/// PIMPL state for the transform, for a single function.
+struct MergeReturn::State {
+ /// The IR module.
+ Module* ir = nullptr;
+ /// The IR builder.
+ Builder b{*ir};
+ /// The type manager.
+ type::Manager& ty{ir->Types()};
+
+ /// The "is returning" flag.
+ Var* return_flag = nullptr;
+
+ /// The return value.
+ Var* return_val = nullptr;
+
+ /// A set of merge blocks that have already been processed.
+ utils::Hashset<Block*, 8> processed_merges;
+
+ /// The final merge block that will contain the unique return instruction.
+ Block* final_merge = nullptr;
+
+ /// Track whether the return flag was actually used to conditionalize a merge block.
+ bool uses_return_flag = false;
+
+ /// Constructor
+ /// @param mod the module
+ explicit State(Module* mod) : ir(mod) {}
+
+ /// Get the nearest non-merge parent block of `block`.
+ /// @param block the block
+ /// @returns the enclosing non-merge block
+ Block* GetEnclosingNonMergeBlock(Block* block) {
+ while (block->Is<MultiInBlock>()) {
+ auto* parent = block->Parent();
+ if (auto* loop = parent->As<Loop>()) {
+ if (block != loop->Merge()) {
+ break;
+ }
+ }
+ block = parent->Block();
+ }
+ return block;
+ }
+
+ /// Process the function.
+ /// @param func the function to process
+ void Process(Function* func) {
+ // Find all of the return instructions in the function.
+ utils::Vector<Return*, 4> returns;
+ for (const auto& usage : func->Usages()) {
+ if (auto* ret = usage.instruction->As<Return>()) {
+ returns.Push(ret);
+ }
+ }
+
+ // If there are no returns, or just a single return at the end of the function (potentially
+ // inside a nested merge block), then nothing needs to be done.
+ if (returns.Length() == 0 ||
+ (returns.Length() == 1 &&
+ GetEnclosingNonMergeBlock(returns[0]->Block()) == func->StartTarget())) {
+ return;
+ }
+
+ // Create a boolean variable that can be used to check whether the function is returning.
+ return_flag = b.Var(ty.ptr<function, bool>());
+ return_flag->SetInitializer(b.Constant(false));
+ func->StartTarget()->Prepend(return_flag);
+ ir->SetName(return_flag, "return_flag");
+
+ // Create a variable to hold the return value if needed.
+ if (!func->ReturnType()->Is<type::Void>()) {
+ return_val = b.Var(ty.ptr(function, func->ReturnType()));
+ func->StartTarget()->Prepend(return_val);
+ ir->SetName(return_val, "return_value");
+ }
+
+ // Process all of the return instructions in the function.
+ for (auto* ret : returns) {
+ ProcessReturn(ret);
+ }
+
+ // Add the unique return instruction to the final merge block if needed.
+ if (final_merge) {
+ if (return_val) {
+ auto* retval = final_merge->Append(b.Load(return_val));
+ final_merge->Append(b.Return(func, retval));
+ } else {
+ final_merge->Append(b.Return(func));
+ }
+ }
+
+ // If the return flag was never actually read from, remove it and the corresponding stores.
+ if (!uses_return_flag) {
+ for (const auto& use : return_flag->Result()->Usages()) {
+ use.instruction->Remove();
+ }
+ return_flag->Remove();
+ }
+ }
+
+ /// Process a return instruction.
+ /// @param ret the return instruction
+ void ProcessReturn(Return* ret) {
+ // If this return is at the end of the function, with no value, then we can leave it alone.
+ if (GetEnclosingNonMergeBlock(ret->Block()) == ret->Func()->StartTarget() &&
+ ret->Block()->Length() == 1 && !ret->Value()) {
+ return;
+ }
+
+ // Set the "is returning" flag to `true`, and record the return value if present.
+ b.Store(return_flag, b.Constant(true))->InsertBefore(ret);
+ if (return_val) {
+ b.Store(return_val, ret->Value())->InsertBefore(ret);
+ }
+
+ // Exit from the containing block, which will recursively insert conditionals into the
+ // containing merge blocks as necessary, eventually inserting a unique return instruction.
+ ExitFromBlock(ret->Block());
+ ret->Remove();
+ if (ret->Value()) {
+ ret->Value()->RemoveUsage({ret, 0u});
+ }
+ }
+
+ /// Process a merge block by wrapping its existing instructions (if any) in a conditional such
+ /// that they will only execute if we are not returning.
+ /// @param merge the merge block to process
+ void ProcessMerge(MultiInBlock* merge) {
+ if (processed_merges.Contains(merge)) {
+ return;
+ }
+ processed_merges.Add(merge);
+
+ // If the merge block was empty, we just need to exit from it.
+ if (merge->IsEmpty()) {
+ ExitFromBlock(merge);
+ return;
+ }
+
+ if (merge->Length() == 1) {
+ // If the block only contains an exit_{if,loop,switch}, we can skip adding a conditional
+ // around its contents and just recurse into the parent merge block.
+ if (utils::IsAnyOf<ExitIf, ExitLoop, ExitSwitch>(merge->Branch())) {
+ tint::Switch(
+ merge->Branch(), //
+ [&](If* ifelse) { ProcessMerge(ifelse->Merge()); },
+ [&](Loop* loop) { ProcessMerge(loop->Merge()); },
+ [&](Switch* swtch) { ProcessMerge(swtch->Merge()); });
+ return;
+ }
+
+ // If the block only contains a return (with no value), we don't need to do anything.
+ if (auto* ret = utils::As<Return>(merge->Branch())) {
+ if (!ret->Value()) {
+ return;
+ }
+ }
+ }
+
+ // Wrap the existing contents of the merge block in a conditional so that it will only
+ // execute if the "is returning" flag is `false`.
+ uses_return_flag = true;
+ auto* condition = b.Load(return_flag);
+ auto* ifelse = b.If(condition);
+
+ // Move all pre-existing instructions to the new false block.
+ while (!merge->IsEmpty()) {
+ auto* inst = merge->Front();
+ inst->Remove();
+ ifelse->False()->Append(inst);
+ }
+
+ // Now the merge block will just contain the new conditional.
+ merge->Append(condition);
+ merge->Append(ifelse);
+
+ utils::Vector<Value*, 4> block_args_from_true;
+ utils::Vector<BlockParam*, 4> merge_block_params;
+ if (auto* exitif = ifelse->False()->Back()->As<ExitIf>()) {
+ // If the previous terminator was an exit_if, we need replace it with one that exits to
+ // the new merge block, and propagate the original basic block arguments if any.
+ // The exit_if from the `true` block will just pass undef values to the merge block.
+ utils::Vector<Value*, 4> block_args_from_false;
+ for (uint32_t i = 0; i < exitif->Args().Length(); i++) {
+ block_args_from_true.Push(nullptr);
+ block_args_from_false.Push(exitif->Args()[i]);
+ merge_block_params.Push(b.BlockParam(exitif->If()->Merge()->Params()[i]->Type()));
+ }
+ exitif->ReplaceWith(b.ExitIf(ifelse, block_args_from_false));
+ } else {
+ // If this merge block was the final merge block of the function, it won't have a branch
+ // yet. Add an `exit_if` to the new merge block, and record the new merge block as the
+ // new final merge block.
+ if (merge == final_merge) {
+ ifelse->False()->Append(b.ExitIf(ifelse));
+ final_merge = ifelse->Merge();
+ }
+ }
+
+ // Exit from the `true` block to the new merge block.
+ ifelse->True()->Append(b.ExitIf(ifelse, block_args_from_true));
+
+ // Exit from the new merge block, which will recursively process the parent merge.
+ ifelse->Merge()->SetParams(merge_block_params);
+ ExitFromBlock(ifelse->Merge(), merge_block_params);
+
+ // We never need to process the merge that we've just added, as it only exits.
+ processed_merges.Add(ifelse->Merge());
+ }
+
+ /// Add an exit_{if,loop,switch} instruction to `block`, and process the target merge block.
+ /// @param block the block to exit from
+ /// @param args the optional basic block arguments
+ void ExitFromBlock(Block* block, utils::VectorRef<Value*> args = utils::Empty) {
+ // Helper to get the block arguments for an instruction that is exiting to `merge`.
+ auto block_args = [&](auto* merge) -> utils::Vector<Value*, 4> {
+ // If arguments were explicitly provided, use those.
+ if (!args.IsEmpty()) {
+ return args;
+ }
+
+ // Otherwise, we will pass a list of `undef` values.
+ utils::Vector<Value*, 4> undef_args;
+ undef_args.Resize(merge->Params().Length(), nullptr);
+ return undef_args;
+ };
+
+ auto* parent_control_flow = GetEnclosingNonMergeBlock(block)->Parent();
+ tint::Switch(
+ parent_control_flow,
+ [&](If* ifelse) {
+ ProcessMerge(ifelse->Merge());
+ block->Append(b.ExitIf(ifelse, block_args(ifelse->Merge())));
+ },
+ [&](Loop* loop) {
+ ProcessMerge(loop->Merge());
+ block->Append(b.ExitLoop(loop, block_args(loop->Merge())));
+ },
+ [&](Switch* swtch) {
+ ProcessMerge(swtch->Merge());
+ block->Append(b.ExitSwitch(swtch, block_args(swtch->Merge())));
+ },
+ [&](Default) {
+ // This is the top-level merge block, so just record it so that we can add the
+ // unique return instruction to it later.
+ final_merge = block;
+ });
+ }
+};
+
+void MergeReturn::Run(Module* ir, const DataMap&, DataMap&) const {
+ // Process each function.
+ for (auto* func : ir->functions) {
+ State state(ir);
+ state.Process(func);
+ }
+}
+
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/merge_return.h b/src/tint/ir/transform/merge_return.h
new file mode 100644
index 0000000..d127366
--- /dev/null
+++ b/src/tint/ir/transform/merge_return.h
@@ -0,0 +1,40 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_MERGE_RETURN_H_
+#define SRC_TINT_IR_TRANSFORM_MERGE_RETURN_H_
+
+#include "src/tint/ir/transform/transform.h"
+
+namespace tint::ir::transform {
+
+/// MergeReturn is a transform merges multiple return statements in a function into a single return
+/// at the end of the function.
+class MergeReturn final : public utils::Castable<MergeReturn, Transform> {
+ public:
+ /// Constructor
+ MergeReturn();
+ /// Destructor
+ ~MergeReturn() override;
+
+ /// @copydoc Transform::Run
+ void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ir::transform
+
+#endif // SRC_TINT_IR_TRANSFORM_MERGE_RETURN_H_
diff --git a/src/tint/ir/transform/merge_return_test.cc b/src/tint/ir/transform/merge_return_test.cc
new file mode 100644
index 0000000..747c995
--- /dev/null
+++ b/src/tint/ir/transform/merge_return_test.cc
@@ -0,0 +1,2413 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/merge_return.h"
+
+#include <utility>
+
+#include "src/tint/ir/transform/test_helper.h"
+
+namespace tint::ir::transform {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+using IR_MergeReturnTest = TransformTest;
+
+TEST_F(IR_MergeReturnTest, NoModify_SingleReturnInRootBlock) {
+ auto* in = b.FunctionParam(ty.i32());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({in});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+ sb.Return(func, sb.Add(ty.i32(), in, 1_i));
+
+ auto* src = R"(
+%foo = func(%2:i32):i32 -> %b1 {
+ %b1 = block {
+ %3:i32 = add %2, 1i
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, NoModify_SingleReturnInMergeBlock) {
+ auto* in = b.FunctionParam(ty.i32());
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({in});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse = sb.If(cond);
+ auto tb = b.With(ifelse->True());
+ tb.ExitIf(ifelse, tb.Add(ty.i32(), in, 1_i));
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse, fb.Add(ty.i32(), in, 2_i));
+ auto mb = b.With(ifelse->Merge());
+ auto* merge_param = b.BlockParam(ty.i32());
+ ifelse->Merge()->SetParams({merge_param});
+ mb.Return(func, merge_param);
+
+ auto* src = R"(
+%foo = func(%2:i32):i32 -> %b1 {
+ %b1 = block {
+ if %3 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ %4:i32 = add %2, 1i
+ exit_if %b4 %4
+ }
+
+ # False block
+ %b3 = block {
+ %5:i32 = add %2, 2i
+ exit_if %b4 %5
+ }
+
+ # Merge block
+ %b4 = block (%6:i32) {
+ ret %6:i32
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, NoModify_SingleReturnInNestedMergeBlock) {
+ auto* in = b.FunctionParam(ty.i32());
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({in});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* swtch = sb.Switch(in);
+ b.Case(swtch, {Switch::CaseSelector{}})->Append(b.ExitSwitch(swtch));
+
+ auto* loop = b.Loop();
+ swtch->Merge()->Append(loop);
+
+ auto* ifelse = b.If(cond);
+ loop->Merge()->Append(ifelse);
+ auto tb = b.With(ifelse->True());
+ tb.ExitIf(ifelse, tb.Add(ty.i32(), in, 1_i));
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse, fb.Add(ty.i32(), in, 2_i));
+ auto mb = b.With(ifelse->Merge());
+ auto* merge_param = b.BlockParam(ty.i32());
+ ifelse->Merge()->SetParams({merge_param});
+ mb.Return(func, merge_param);
+
+ auto* src = R"(
+%foo = func(%2:i32):i32 -> %b1 {
+ %b1 = block {
+ switch %2 [c: (default, %b2), m: %b3]
+ # Case block
+ %b2 = block {
+ exit_switch %b3
+ }
+
+ # Merge block
+ %b3 = block {
+ loop [m: %b4]
+ # Merge block
+ %b4 = block {
+ if %3 [t: %b5, f: %b6, m: %b7]
+ # True block
+ %b5 = block {
+ %4:i32 = add %2, 1i
+ exit_if %b7 %4
+ }
+
+ # False block
+ %b6 = block {
+ %5:i32 = add %2, 2i
+ exit_if %b7 %5
+ }
+
+ # Merge block
+ %b7 = block (%6:i32) {
+ ret %6:i32
+ }
+
+ }
+
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, IfElse_OneSideReturns) {
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse = sb.If(cond);
+ auto tb = b.With(ifelse->True());
+ tb.Return(func);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse);
+ auto mb = b.With(ifelse->Merge());
+ mb.Return(func);
+
+ auto* src = R"(
+%foo = func(%2:bool):void -> %b1 {
+ %b1 = block {
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ ret
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ ret
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%2:bool):void -> %b1 {
+ %b1 = block {
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ exit_if %b4
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ ret
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+// This is the same as the above tests, but we create the return instructions in a different order
+// to make sure that creation order doesn't matter.
+TEST_F(IR_MergeReturnTest, IfElse_OneSideReturns_ReturnsCreatedInDifferentOrder) {
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse = sb.If(cond);
+ auto mb = b.With(ifelse->Merge());
+ mb.Return(func);
+ auto tb = b.With(ifelse->True());
+ tb.Return(func);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse);
+
+ auto* src = R"(
+%foo = func(%2:bool):void -> %b1 {
+ %b1 = block {
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ ret
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ ret
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%2:bool):void -> %b1 {
+ %b1 = block {
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ exit_if %b4
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ ret
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, IfElse_OneSideReturns_WithValue) {
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse = sb.If(cond);
+ auto tb = b.With(ifelse->True());
+ tb.Return(func, 1_i);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse);
+ auto mb = b.With(ifelse->Merge());
+ mb.Return(func, 2_i);
+
+ auto* src = R"(
+%foo = func(%2:bool):i32 -> %b1 {
+ %b1 = block {
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ ret 1i
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ ret 2i
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%2:bool):i32 -> %b1 {
+ %b1 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ store %return_flag, true
+ store %return_value, 1i
+ exit_if %b4
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ %5:bool = load %return_flag
+ if %5 [t: %b5, f: %b6, m: %b7]
+ # True block
+ %b5 = block {
+ exit_if %b7
+ }
+
+ # False block
+ %b6 = block {
+ store %return_flag, true
+ store %return_value, 2i
+ exit_if %b7
+ }
+
+ # Merge block
+ %b7 = block {
+ %6:i32 = load %return_value
+ ret %6
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, IfElse_OneSideReturns_WithValue_MergeHasBasicBlockArguments) {
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse = sb.If(cond);
+ auto tb = b.With(ifelse->True());
+ tb.Return(func, 1_i);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse, 2_i);
+ auto mb = b.With(ifelse->Merge());
+ auto* merge_param = b.BlockParam(ty.i32());
+ ifelse->Merge()->SetParams({merge_param});
+ mb.Return(func, merge_param);
+
+ auto* src = R"(
+%foo = func(%2:bool):i32 -> %b1 {
+ %b1 = block {
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ ret 1i
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4 2i
+ }
+
+ # Merge block
+ %b4 = block (%3:i32) {
+ ret %3:i32
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%2:bool):i32 -> %b1 {
+ %b1 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ store %return_flag, true
+ store %return_value, 1i
+ exit_if %b4 undef
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4 2i
+ }
+
+ # Merge block
+ %b4 = block (%5:i32) {
+ %6:bool = load %return_flag
+ if %6 [t: %b5, f: %b6, m: %b7]
+ # True block
+ %b5 = block {
+ exit_if %b7
+ }
+
+ # False block
+ %b6 = block {
+ store %return_flag, true
+ store %return_value, %5:i32
+ exit_if %b7
+ }
+
+ # Merge block
+ %b7 = block {
+ %7:i32 = load %return_value
+ ret %7
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, IfElse_OneSideReturns_WithValue_MergeHasUndefBasicBlockArguments) {
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse = sb.If(cond);
+ auto tb = b.With(ifelse->True());
+ tb.Return(func, 1_i);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse, nullptr);
+ auto mb = b.With(ifelse->Merge());
+ auto* merge_param = b.BlockParam(ty.i32());
+ ifelse->Merge()->SetParams({merge_param});
+ mb.Return(func, merge_param);
+
+ auto* src = R"(
+%foo = func(%2:bool):i32 -> %b1 {
+ %b1 = block {
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ ret 1i
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4 undef
+ }
+
+ # Merge block
+ %b4 = block (%3:i32) {
+ ret %3:i32
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%2:bool):i32 -> %b1 {
+ %b1 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ store %return_flag, true
+ store %return_value, 1i
+ exit_if %b4 undef
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4 undef
+ }
+
+ # Merge block
+ %b4 = block (%5:i32) {
+ %6:bool = load %return_flag
+ if %6 [t: %b5, f: %b6, m: %b7]
+ # True block
+ %b5 = block {
+ exit_if %b7
+ }
+
+ # False block
+ %b6 = block {
+ store %return_flag, true
+ store %return_value, %5:i32
+ exit_if %b7
+ }
+
+ # Merge block
+ %b7 = block {
+ %7:i32 = load %return_value
+ ret %7
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, IfElse_BothSidesReturn) {
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse = sb.If(cond);
+ auto tb = b.With(ifelse->True());
+ tb.Return(func);
+ auto fb = b.With(ifelse->False());
+ fb.Return(func);
+
+ auto* src = R"(
+%foo = func(%2:bool):void -> %b1 {
+ %b1 = block {
+ if %2 [t: %b2, f: %b3]
+ # True block
+ %b2 = block {
+ ret
+ }
+
+ # False block
+ %b3 = block {
+ ret
+ }
+
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%2:bool):void -> %b1 {
+ %b1 = block {
+ if %2 [t: %b2, f: %b3, m: %b4]
+ # True block
+ %b2 = block {
+ exit_if %b4
+ }
+
+ # False block
+ %b3 = block {
+ exit_if %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ ret
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, IfElse_NonEmptyMergeBlock) {
+ auto* global = b.Var(ty.ptr<private_, i32>());
+ b.RootBlock()->Append(global);
+
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse = sb.If(cond);
+ auto tb = b.With(ifelse->True());
+ tb.Return(func);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse);
+ auto mb = b.With(ifelse->Merge());
+ mb.Store(global, 42_i);
+ mb.Return(func);
+
+ auto* src = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):void -> %b2 {
+ %b2 = block {
+ if %3 [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ ret
+ }
+
+ # False block
+ %b4 = block {
+ exit_if %b5
+ }
+
+ # Merge block
+ %b5 = block {
+ store %1, 42i
+ ret
+ }
+
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):void -> %b2 {
+ %b2 = block {
+ %return_flag:ptr<function, bool, read_write> = var, false
+ if %3 [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ store %return_flag, true
+ exit_if %b5
+ }
+
+ # False block
+ %b4 = block {
+ exit_if %b5
+ }
+
+ # Merge block
+ %b5 = block {
+ %5:bool = load %return_flag
+ if %5 [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ exit_if %b8
+ }
+
+ # False block
+ %b7 = block {
+ store %1, 42i
+ store %return_flag, true
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ ret
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+// This is the same as the above tests, but we create the return instructions in a different order
+// to make sure that creation order doesn't matter.
+TEST_F(IR_MergeReturnTest, IfElse_NonEmptyMergeBlock_ReturnsCreatedInDifferentOrder) {
+ auto* global = b.Var(ty.ptr<private_, i32>());
+ b.RootBlock()->Append(global);
+
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse = sb.If(cond);
+ auto mb = b.With(ifelse->Merge());
+ mb.Store(global, 42_i);
+ mb.Return(func);
+ auto tb = b.With(ifelse->True());
+ tb.Return(func);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse);
+
+ auto* src = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):void -> %b2 {
+ %b2 = block {
+ if %3 [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ ret
+ }
+
+ # False block
+ %b4 = block {
+ exit_if %b5
+ }
+
+ # Merge block
+ %b5 = block {
+ store %1, 42i
+ ret
+ }
+
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):void -> %b2 {
+ %b2 = block {
+ %return_flag:ptr<function, bool, read_write> = var, false
+ if %3 [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ store %return_flag, true
+ exit_if %b5
+ }
+
+ # False block
+ %b4 = block {
+ exit_if %b5
+ }
+
+ # Merge block
+ %b5 = block {
+ %5:bool = load %return_flag
+ if %5 [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ exit_if %b8
+ }
+
+ # False block
+ %b7 = block {
+ store %1, 42i
+ store %return_flag, true
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ ret
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, IfElse_Nested) {
+ auto* global = b.Var(ty.ptr<private_, i32>());
+ b.RootBlock()->Append(global);
+
+ auto* func = b.Function("foo", ty.i32());
+ auto* condA = b.FunctionParam(ty.bool_());
+ auto* condB = b.FunctionParam(ty.bool_());
+ auto* condC = b.FunctionParam(ty.bool_());
+ mod.SetName(condA, "condA");
+ mod.SetName(condB, "condB");
+ mod.SetName(condC, "condC");
+ func->SetParams({condA, condB, condC});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse_outer = sb.If(condA);
+ auto outer_true = b.With(ifelse_outer->True());
+ outer_true.Return(func, 3_i);
+ auto outer_false = b.With(ifelse_outer->False());
+ auto* ifelse_middle = outer_false.If(condB);
+ auto outer_merge = b.With(ifelse_outer->Merge());
+ outer_merge.Store(global, 3_i);
+ outer_merge.Return(func, outer_merge.Add(ty.i32(), 5_i, 6_i));
+
+ auto middle_true = b.With(ifelse_middle->True());
+ auto* ifelse_inner = middle_true.If(condC);
+ auto middle_false = b.With(ifelse_middle->False());
+ middle_false.ExitIf(ifelse_middle);
+ auto middle_merge = b.With(ifelse_middle->Merge());
+ middle_merge.Store(global, 2_i);
+ middle_merge.ExitIf(ifelse_outer);
+
+ auto inner_true = b.With(ifelse_inner->True());
+ inner_true.Return(func, 1_i);
+ auto inner_false = b.With(ifelse_inner->False());
+ inner_false.ExitIf(ifelse_inner);
+ auto inner_merge = b.With(ifelse_inner->Merge());
+ inner_merge.Store(global, 1_i);
+ inner_merge.Return(func, 2_i);
+
+ auto* src = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%condA:bool, %condB:bool, %condC:bool):i32 -> %b2 {
+ %b2 = block {
+ if %condA [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ ret 3i
+ }
+
+ # False block
+ %b4 = block {
+ if %condB [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ if %condC [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ ret 1i
+ }
+
+ # False block
+ %b10 = block {
+ exit_if %b11
+ }
+
+ # Merge block
+ %b11 = block {
+ store %1, 1i
+ ret 2i
+ }
+
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ store %1, 2i
+ exit_if %b5
+ }
+
+ }
+
+ # Merge block
+ %b5 = block {
+ store %1, 3i
+ %6:i32 = add 5i, 6i
+ ret %6
+ }
+
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%condA:bool, %condB:bool, %condC:bool):i32 -> %b2 {
+ %b2 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ if %condA [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ store %return_flag, true
+ store %return_value, 3i
+ exit_if %b5
+ }
+
+ # False block
+ %b4 = block {
+ if %condB [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ if %condC [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ store %return_flag, true
+ store %return_value, 1i
+ exit_if %b11
+ }
+
+ # False block
+ %b10 = block {
+ exit_if %b11
+ }
+
+ # Merge block
+ %b11 = block {
+ %8:bool = load %return_flag
+ if %8 [t: %b12, f: %b13, m: %b14]
+ # True block
+ %b12 = block {
+ exit_if %b14
+ }
+
+ # False block
+ %b13 = block {
+ store %1, 1i
+ store %return_flag, true
+ store %return_value, 2i
+ exit_if %b14
+ }
+
+ # Merge block
+ %b14 = block {
+ exit_if %b8
+ }
+
+ }
+
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ %9:bool = load %return_flag
+ if %9 [t: %b15, f: %b16, m: %b17]
+ # True block
+ %b15 = block {
+ exit_if %b17
+ }
+
+ # False block
+ %b16 = block {
+ store %1, 2i
+ exit_if %b17
+ }
+
+ # Merge block
+ %b17 = block {
+ exit_if %b5
+ }
+
+ }
+
+ }
+
+ # Merge block
+ %b5 = block {
+ %10:bool = load %return_flag
+ if %10 [t: %b18, f: %b19, m: %b20]
+ # True block
+ %b18 = block {
+ exit_if %b20
+ }
+
+ # False block
+ %b19 = block {
+ store %1, 3i
+ %11:i32 = add 5i, 6i
+ store %return_flag, true
+ store %return_value, %11
+ exit_if %b20
+ }
+
+ # Merge block
+ %b20 = block {
+ %12:i32 = load %return_value
+ ret %12
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, IfElse_Nested_TrivialMerge) {
+ auto* global = b.Var(ty.ptr<private_, i32>());
+ b.RootBlock()->Append(global);
+
+ auto* func = b.Function("foo", ty.i32());
+ auto* condA = b.FunctionParam(ty.bool_());
+ auto* condB = b.FunctionParam(ty.bool_());
+ auto* condC = b.FunctionParam(ty.bool_());
+ mod.SetName(condA, "condA");
+ mod.SetName(condB, "condB");
+ mod.SetName(condC, "condC");
+ func->SetParams({condA, condB, condC});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse_outer = sb.If(condA);
+ auto outer_true = b.With(ifelse_outer->True());
+ outer_true.Return(func, 3_i);
+ auto outer_false = b.With(ifelse_outer->False());
+ auto* ifelse_middle = outer_false.If(condB);
+ auto outer_merge = b.With(ifelse_outer->Merge());
+ outer_merge.Return(func, 3_i);
+
+ auto middle_true = b.With(ifelse_middle->True());
+ auto* ifelse_inner = middle_true.If(condC);
+ auto middle_false = b.With(ifelse_middle->False());
+ middle_false.ExitIf(ifelse_middle);
+ auto middle_merge = b.With(ifelse_middle->Merge());
+ middle_merge.ExitIf(ifelse_outer);
+
+ auto inner_true = b.With(ifelse_inner->True());
+ inner_true.Return(func, 1_i);
+ auto inner_false = b.With(ifelse_inner->False());
+ inner_false.ExitIf(ifelse_inner);
+ auto inner_merge = b.With(ifelse_inner->Merge());
+ inner_merge.ExitIf(ifelse_middle);
+
+ auto* src = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%condA:bool, %condB:bool, %condC:bool):i32 -> %b2 {
+ %b2 = block {
+ if %condA [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ ret 3i
+ }
+
+ # False block
+ %b4 = block {
+ if %condB [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ if %condC [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ ret 1i
+ }
+
+ # False block
+ %b10 = block {
+ exit_if %b11
+ }
+
+ # Merge block
+ %b11 = block {
+ exit_if %b8
+ }
+
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ exit_if %b5
+ }
+
+ }
+
+ # Merge block
+ %b5 = block {
+ ret 3i
+ }
+
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%condA:bool, %condB:bool, %condC:bool):i32 -> %b2 {
+ %b2 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ if %condA [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ store %return_flag, true
+ store %return_value, 3i
+ exit_if %b5
+ }
+
+ # False block
+ %b4 = block {
+ if %condB [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ if %condC [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ store %return_flag, true
+ store %return_value, 1i
+ exit_if %b11
+ }
+
+ # False block
+ %b10 = block {
+ exit_if %b11
+ }
+
+ # Merge block
+ %b11 = block {
+ exit_if %b8
+ }
+
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ exit_if %b5
+ }
+
+ }
+
+ # Merge block
+ %b5 = block {
+ %8:bool = load %return_flag
+ if %8 [t: %b12, f: %b13, m: %b14]
+ # True block
+ %b12 = block {
+ exit_if %b14
+ }
+
+ # False block
+ %b13 = block {
+ store %return_flag, true
+ store %return_value, 3i
+ exit_if %b14
+ }
+
+ # Merge block
+ %b14 = block {
+ %9:i32 = load %return_value
+ ret %9
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, IfElse_Nested_WithBasicBlockArguments) {
+ auto* global = b.Var(ty.ptr<private_, i32>());
+ b.RootBlock()->Append(global);
+
+ auto* func = b.Function("foo", ty.i32());
+ auto* condA = b.FunctionParam(ty.bool_());
+ auto* condB = b.FunctionParam(ty.bool_());
+ auto* condC = b.FunctionParam(ty.bool_());
+ mod.SetName(condA, "condA");
+ mod.SetName(condB, "condB");
+ mod.SetName(condC, "condC");
+ func->SetParams({condA, condB, condC});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* ifelse_outer = sb.If(condA);
+ auto outer_true = b.With(ifelse_outer->True());
+ outer_true.Return(func, 3_i);
+ auto outer_false = b.With(ifelse_outer->False());
+ auto* ifelse_middle = outer_false.If(condB);
+ auto outer_merge = b.With(ifelse_outer->Merge());
+ auto* outer_param = b.BlockParam(ty.i32());
+ ifelse_outer->Merge()->SetParams({outer_param});
+ outer_merge.Return(func, outer_merge.Add(ty.i32(), outer_param, 1_i));
+
+ auto middle_true = b.With(ifelse_middle->True());
+ auto* ifelse_inner = middle_true.If(condC);
+ auto middle_false = b.With(ifelse_middle->False());
+ middle_false.ExitIf(ifelse_middle, middle_false.Add(ty.i32(), 43_i, 2_i));
+ auto middle_merge = b.With(ifelse_middle->Merge());
+ auto* middle_param = b.BlockParam(ty.i32());
+ ifelse_middle->Merge()->SetParams({middle_param});
+ middle_merge.ExitIf(ifelse_outer, middle_merge.Add(ty.i32(), middle_param, 1_i));
+
+ auto inner_true = b.With(ifelse_inner->True());
+ inner_true.Return(func, 1_i);
+ auto inner_false = b.With(ifelse_inner->False());
+ inner_false.ExitIf(ifelse_inner);
+ auto inner_merge = b.With(ifelse_inner->Merge());
+ inner_merge.ExitIf(ifelse_middle, inner_merge.Add(ty.i32(), 42_i, 1_i));
+
+ auto* src = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%condA:bool, %condB:bool, %condC:bool):i32 -> %b2 {
+ %b2 = block {
+ if %condA [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ ret 3i
+ }
+
+ # False block
+ %b4 = block {
+ if %condB [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ if %condC [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ ret 1i
+ }
+
+ # False block
+ %b10 = block {
+ exit_if %b11
+ }
+
+ # Merge block
+ %b11 = block {
+ %6:i32 = add 42i, 1i
+ exit_if %b8 %6
+ }
+
+ }
+
+ # False block
+ %b7 = block {
+ %7:i32 = add 43i, 2i
+ exit_if %b8 %7
+ }
+
+ # Merge block
+ %b8 = block (%8:i32) {
+ %9:i32 = add %8:i32, 1i
+ exit_if %b5 %9
+ }
+
+ }
+
+ # Merge block
+ %b5 = block (%10:i32) {
+ %11:i32 = add %10:i32, 1i
+ ret %11
+ }
+
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%condA:bool, %condB:bool, %condC:bool):i32 -> %b2 {
+ %b2 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ if %condA [t: %b3, f: %b4, m: %b5]
+ # True block
+ %b3 = block {
+ store %return_flag, true
+ store %return_value, 3i
+ exit_if %b5 undef
+ }
+
+ # False block
+ %b4 = block {
+ if %condB [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ if %condC [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ store %return_flag, true
+ store %return_value, 1i
+ exit_if %b11
+ }
+
+ # False block
+ %b10 = block {
+ exit_if %b11
+ }
+
+ # Merge block
+ %b11 = block {
+ %8:bool = load %return_flag
+ if %8 [t: %b12, f: %b13, m: %b14]
+ # True block
+ %b12 = block {
+ exit_if %b14 undef
+ }
+
+ # False block
+ %b13 = block {
+ %9:i32 = add 42i, 1i
+ exit_if %b14 %9
+ }
+
+ # Merge block
+ %b14 = block (%10:i32) {
+ exit_if %b8 %10:i32
+ }
+
+ }
+
+ }
+
+ # False block
+ %b7 = block {
+ %11:i32 = add 43i, 2i
+ exit_if %b8 %11
+ }
+
+ # Merge block
+ %b8 = block (%12:i32) {
+ %13:bool = load %return_flag
+ if %13 [t: %b15, f: %b16, m: %b17]
+ # True block
+ %b15 = block {
+ exit_if %b17 undef
+ }
+
+ # False block
+ %b16 = block {
+ %14:i32 = add %12:i32, 1i
+ exit_if %b17 %14
+ }
+
+ # Merge block
+ %b17 = block (%15:i32) {
+ exit_if %b5 %15:i32
+ }
+
+ }
+
+ }
+
+ # Merge block
+ %b5 = block (%16:i32) {
+ %17:bool = load %return_flag
+ if %17 [t: %b18, f: %b19, m: %b20]
+ # True block
+ %b18 = block {
+ exit_if %b20
+ }
+
+ # False block
+ %b19 = block {
+ %18:i32 = add %16:i32, 1i
+ store %return_flag, true
+ store %return_value, %18
+ exit_if %b20
+ }
+
+ # Merge block
+ %b20 = block {
+ %19:i32 = load %return_value
+ ret %19
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, Loop_UnconditionalReturnInBody) {
+ auto* func = b.Function("foo", ty.i32());
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* loop = sb.Loop();
+ loop->Body()->Append(b.Return(func, 42_i));
+
+ auto* src = R"(
+%foo = func():i32 -> %b1 {
+ %b1 = block {
+ loop [b: %b2]
+ # Body block
+ %b2 = block {
+ ret 42i
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func():i32 -> %b1 {
+ %b1 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ loop [b: %b2, m: %b3]
+ # Body block
+ %b2 = block {
+ store %return_value, 42i
+ exit_loop %b3
+ }
+
+ # Merge block
+ %b3 = block {
+ %3:i32 = load %return_value
+ ret %3
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, Loop_ConditionalReturnInBody) {
+ auto* global = b.Var(ty.ptr<private_, i32>());
+ b.RootBlock()->Append(global);
+
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* loop = sb.Loop();
+ auto lb = b.With(loop->Body());
+ auto* ifelse = lb.If(cond);
+ {
+ auto tb = b.With(ifelse->True());
+ tb.Return(func, 42_i);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse);
+ auto mb = b.With(ifelse->Merge());
+ mb.Store(global, 2_i);
+ mb.Continue(loop);
+ }
+ auto cb = b.With(loop->Continuing());
+ cb.Store(global, 1_i);
+ cb.BreakIf(true, loop);
+ auto mb = b.With(loop->Merge());
+ mb.Store(global, 3_i);
+ mb.Return(func, 43_i);
+
+ auto* src = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):i32 -> %b2 {
+ %b2 = block {
+ loop [b: %b3, c: %b4, m: %b5]
+ # Body block
+ %b3 = block {
+ if %3 [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ ret 42i
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ store %1, 2i
+ continue %b4
+ }
+
+ }
+
+ # Continuing block
+ %b4 = block {
+ store %1, 1i
+ break_if true %b3
+ }
+
+ # Merge block
+ %b5 = block {
+ store %1, 3i
+ ret 43i
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):i32 -> %b2 {
+ %b2 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ loop [b: %b3, c: %b4, m: %b5]
+ # Body block
+ %b3 = block {
+ if %3 [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ store %return_flag, true
+ store %return_value, 42i
+ exit_if %b8
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ %6:bool = load %return_flag
+ if %6 [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ exit_if %b11
+ }
+
+ # False block
+ %b10 = block {
+ store %1, 2i
+ continue %b4
+ }
+
+ # Merge block
+ %b11 = block {
+ exit_loop %b5
+ }
+
+ }
+
+ }
+
+ # Continuing block
+ %b4 = block {
+ store %1, 1i
+ break_if true %b3
+ }
+
+ # Merge block
+ %b5 = block {
+ %7:bool = load %return_flag
+ if %7 [t: %b12, f: %b13, m: %b14]
+ # True block
+ %b12 = block {
+ exit_if %b14
+ }
+
+ # False block
+ %b13 = block {
+ store %1, 3i
+ store %return_flag, true
+ store %return_value, 43i
+ exit_if %b14
+ }
+
+ # Merge block
+ %b14 = block {
+ %8:i32 = load %return_value
+ ret %8
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, Loop_ConditionalReturnInBody_UnreachableMerge) {
+ auto* global = b.Var(ty.ptr<private_, i32>());
+ b.RootBlock()->Append(global);
+
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* loop = sb.Loop();
+ auto lb = b.With(loop->Body());
+ auto* ifelse = lb.If(cond);
+ {
+ auto tb = b.With(ifelse->True());
+ tb.Return(func, 42_i);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse);
+ auto mb = b.With(ifelse->Merge());
+ mb.Store(global, 2_i);
+ mb.Continue(loop);
+ }
+ auto cb = b.With(loop->Continuing());
+ cb.Store(global, 1_i);
+ cb.NextIteration(loop);
+
+ auto* src = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):i32 -> %b2 {
+ %b2 = block {
+ loop [b: %b3, c: %b4]
+ # Body block
+ %b3 = block {
+ if %3 [t: %b5, f: %b6, m: %b7]
+ # True block
+ %b5 = block {
+ ret 42i
+ }
+
+ # False block
+ %b6 = block {
+ exit_if %b7
+ }
+
+ # Merge block
+ %b7 = block {
+ store %1, 2i
+ continue %b4
+ }
+
+ }
+
+ # Continuing block
+ %b4 = block {
+ store %1, 1i
+ next_iteration %b3
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):i32 -> %b2 {
+ %b2 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ loop [b: %b3, c: %b4, m: %b5]
+ # Body block
+ %b3 = block {
+ if %3 [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ store %return_flag, true
+ store %return_value, 42i
+ exit_if %b8
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ %6:bool = load %return_flag
+ if %6 [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ exit_if %b11
+ }
+
+ # False block
+ %b10 = block {
+ store %1, 2i
+ continue %b4
+ }
+
+ # Merge block
+ %b11 = block {
+ exit_loop %b5
+ }
+
+ }
+
+ }
+
+ # Continuing block
+ %b4 = block {
+ store %1, 1i
+ next_iteration %b3
+ }
+
+ # Merge block
+ %b5 = block {
+ %7:i32 = load %return_value
+ ret %7
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, Loop_WithBasicBlockArgumentsOnMerge) {
+ auto* global = b.Var(ty.ptr<private_, i32>());
+ b.RootBlock()->Append(global);
+
+ auto* cond = b.FunctionParam(ty.bool_());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* loop = sb.Loop();
+ auto lb = b.With(loop->Body());
+ auto* ifelse = lb.If(cond);
+ {
+ auto tb = b.With(ifelse->True());
+ tb.Return(func, 42_i);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse);
+ auto mb = b.With(ifelse->Merge());
+ mb.Store(global, 2_i);
+ mb.Continue(loop);
+ }
+ auto cb = b.With(loop->Continuing());
+ cb.Store(global, 1_i);
+ cb.BreakIf(true, loop, 4_i);
+ auto* merge_param = b.BlockParam(ty.i32());
+ auto mb = b.With(loop->Merge());
+ loop->Merge()->SetParams({merge_param});
+ mb.Store(global, 3_i);
+ mb.Return(func, merge_param);
+
+ auto* src = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):i32 -> %b2 {
+ %b2 = block {
+ loop [b: %b3, c: %b4, m: %b5]
+ # Body block
+ %b3 = block {
+ if %3 [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ ret 42i
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ store %1, 2i
+ continue %b4
+ }
+
+ }
+
+ # Continuing block
+ %b4 = block {
+ store %1, 1i
+ break_if true %b3 4i
+ }
+
+ # Merge block
+ %b5 = block (%4:i32) {
+ store %1, 3i
+ ret %4:i32
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:bool):i32 -> %b2 {
+ %b2 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ loop [b: %b3, c: %b4, m: %b5]
+ # Body block
+ %b3 = block {
+ if %3 [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ store %return_flag, true
+ store %return_value, 42i
+ exit_if %b8
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ %6:bool = load %return_flag
+ if %6 [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ exit_if %b11
+ }
+
+ # False block
+ %b10 = block {
+ store %1, 2i
+ continue %b4
+ }
+
+ # Merge block
+ %b11 = block {
+ exit_loop %b5 undef
+ }
+
+ }
+
+ }
+
+ # Continuing block
+ %b4 = block {
+ store %1, 1i
+ break_if true %b3 4i
+ }
+
+ # Merge block
+ %b5 = block (%7:i32) {
+ %8:bool = load %return_flag
+ if %8 [t: %b12, f: %b13, m: %b14]
+ # True block
+ %b12 = block {
+ exit_if %b14
+ }
+
+ # False block
+ %b13 = block {
+ store %1, 3i
+ store %return_flag, true
+ store %return_value, %7:i32
+ exit_if %b14
+ }
+
+ # Merge block
+ %b14 = block {
+ %9:i32 = load %return_value
+ ret %9
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, Switch_UnconditionalReturnInCase) {
+ auto* cond = b.FunctionParam(ty.i32());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* sw = sb.Switch(cond);
+ auto caseA = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}));
+ caseA.Return(func, 42_i);
+ auto caseB = b.With(b.Case(sw, {Switch::CaseSelector{}}));
+ caseB.ExitSwitch(sw);
+ auto mb = b.With(sw->Merge());
+ mb.Return(func, 0_i);
+
+ auto* src = R"(
+%foo = func(%2:i32):i32 -> %b1 {
+ %b1 = block {
+ switch %2 [c: (1i, %b2), c: (default, %b3), m: %b4]
+ # Case block
+ %b2 = block {
+ ret 42i
+ }
+
+ # Case block
+ %b3 = block {
+ exit_switch %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ ret 0i
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%2:i32):i32 -> %b1 {
+ %b1 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ switch %2 [c: (1i, %b2), c: (default, %b3), m: %b4]
+ # Case block
+ %b2 = block {
+ store %return_flag, true
+ store %return_value, 42i
+ exit_switch %b4
+ }
+
+ # Case block
+ %b3 = block {
+ exit_switch %b4
+ }
+
+ # Merge block
+ %b4 = block {
+ %5:bool = load %return_flag
+ if %5 [t: %b5, f: %b6, m: %b7]
+ # True block
+ %b5 = block {
+ exit_if %b7
+ }
+
+ # False block
+ %b6 = block {
+ store %return_flag, true
+ store %return_value, 0i
+ exit_if %b7
+ }
+
+ # Merge block
+ %b7 = block {
+ %6:i32 = load %return_value
+ ret %6
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, Switch_ConditionalReturnInBody) {
+ auto* global = b.Var(ty.ptr<private_, i32>());
+ b.RootBlock()->Append(global);
+
+ auto* cond = b.FunctionParam(ty.i32());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* sw = sb.Switch(cond);
+ auto caseA = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}));
+ auto* ifelse = caseA.If(cond);
+ {
+ auto tb = b.With(ifelse->True());
+ tb.Return(func, 42_i);
+ auto fb = b.With(ifelse->False());
+ fb.ExitIf(ifelse);
+ auto mb = b.With(ifelse->Merge());
+ mb.Store(global, 2_i);
+ mb.ExitSwitch(sw);
+ }
+ auto caseB = b.With(b.Case(sw, {Switch::CaseSelector{}}));
+ caseB.ExitSwitch(sw);
+
+ auto mb = b.With(sw->Merge());
+ mb.Return(func, 0_i);
+
+ auto* src = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:i32):i32 -> %b2 {
+ %b2 = block {
+ switch %3 [c: (1i, %b3), c: (default, %b4), m: %b5]
+ # Case block
+ %b3 = block {
+ if %3 [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ ret 42i
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ store %1, 2i
+ exit_switch %b5
+ }
+
+ }
+
+ # Case block
+ %b4 = block {
+ exit_switch %b5
+ }
+
+ # Merge block
+ %b5 = block {
+ ret 0i
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+# Root block
+%b1 = block {
+ %1:ptr<private, i32, read_write> = var
+}
+
+%foo = func(%3:i32):i32 -> %b2 {
+ %b2 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ switch %3 [c: (1i, %b3), c: (default, %b4), m: %b5]
+ # Case block
+ %b3 = block {
+ if %3 [t: %b6, f: %b7, m: %b8]
+ # True block
+ %b6 = block {
+ store %return_flag, true
+ store %return_value, 42i
+ exit_if %b8
+ }
+
+ # False block
+ %b7 = block {
+ exit_if %b8
+ }
+
+ # Merge block
+ %b8 = block {
+ %6:bool = load %return_flag
+ if %6 [t: %b9, f: %b10, m: %b11]
+ # True block
+ %b9 = block {
+ exit_if %b11
+ }
+
+ # False block
+ %b10 = block {
+ store %1, 2i
+ exit_switch %b5
+ }
+
+ # Merge block
+ %b11 = block {
+ exit_switch %b5
+ }
+
+ }
+
+ }
+
+ # Case block
+ %b4 = block {
+ exit_switch %b5
+ }
+
+ # Merge block
+ %b5 = block {
+ %7:bool = load %return_flag
+ if %7 [t: %b12, f: %b13, m: %b14]
+ # True block
+ %b12 = block {
+ exit_if %b14
+ }
+
+ # False block
+ %b13 = block {
+ store %return_flag, true
+ store %return_value, 0i
+ exit_if %b14
+ }
+
+ # Merge block
+ %b14 = block {
+ %8:i32 = load %return_value
+ ret %8
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_MergeReturnTest, Switch_WithBasicBlockArgumentsOnMerge) {
+ auto* cond = b.FunctionParam(ty.i32());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({cond});
+ mod.functions.Push(func);
+
+ auto sb = b.With(func->StartTarget());
+
+ auto* sw = sb.Switch(cond);
+ auto caseA = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}));
+ caseA.Return(func, 42_i);
+ auto caseB = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(2_i)}}));
+ caseB.Return(func, 99_i);
+ auto caseC = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(3_i)}}));
+ caseC.ExitSwitch(sw, 1_i);
+ auto caseD = b.With(b.Case(sw, {Switch::CaseSelector{}}));
+ caseD.ExitSwitch(sw, 0_i);
+
+ auto* merge_param = b.BlockParam(ty.i32());
+ auto mb = b.With(sw->Merge());
+ sw->Merge()->SetParams({merge_param});
+ mb.Return(func, merge_param);
+
+ auto* src = R"(
+%foo = func(%2:i32):i32 -> %b1 {
+ %b1 = block {
+ switch %2 [c: (1i, %b2), c: (2i, %b3), c: (3i, %b4), c: (default, %b5), m: %b6]
+ # Case block
+ %b2 = block {
+ ret 42i
+ }
+
+ # Case block
+ %b3 = block {
+ ret 99i
+ }
+
+ # Case block
+ %b4 = block {
+ exit_switch %b6 1i
+ }
+
+ # Case block
+ %b5 = block {
+ exit_switch %b6 0i
+ }
+
+ # Merge block
+ %b6 = block (%3:i32) {
+ ret %3:i32
+ }
+
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%2:i32):i32 -> %b1 {
+ %b1 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %return_flag:ptr<function, bool, read_write> = var, false
+ switch %2 [c: (1i, %b2), c: (2i, %b3), c: (3i, %b4), c: (default, %b5), m: %b6]
+ # Case block
+ %b2 = block {
+ store %return_flag, true
+ store %return_value, 42i
+ exit_switch %b6 undef
+ }
+
+ # Case block
+ %b3 = block {
+ store %return_flag, true
+ store %return_value, 99i
+ exit_switch %b6 undef
+ }
+
+ # Case block
+ %b4 = block {
+ exit_switch %b6 1i
+ }
+
+ # Case block
+ %b5 = block {
+ exit_switch %b6 0i
+ }
+
+ # Merge block
+ %b6 = block (%5:i32) {
+ %6:bool = load %return_flag
+ if %6 [t: %b7, f: %b8, m: %b9]
+ # True block
+ %b7 = block {
+ exit_if %b9
+ }
+
+ # False block
+ %b8 = block {
+ store %return_flag, true
+ store %return_value, %5:i32
+ exit_if %b9
+ }
+
+ # Merge block
+ %b9 = block {
+ %7:i32 = load %return_value
+ ret %7
+ }
+
+ }
+
+ }
+}
+)";
+
+ Run<MergeReturn>();
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/transform.h b/src/tint/ir/transform/transform.h
index a09fe77..e95a7b7 100644
--- a/src/tint/ir/transform/transform.h
+++ b/src/tint/ir/transform/transform.h
@@ -19,6 +19,7 @@
#include <utility>
+#include "src/tint/builtin/address_space.h"
#include "src/tint/utils/castable.h"
// Forward declarations
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 3aecc75..9aa5b9d 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -41,6 +41,7 @@
#include "src/tint/ir/switch.h"
#include "src/tint/ir/transform/add_empty_entry_point.h"
#include "src/tint/ir/transform/block_decorated_structs.h"
+#include "src/tint/ir/transform/merge_return.h"
#include "src/tint/ir/transform/var_for_dynamic_index.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/validate.h"
@@ -72,6 +73,7 @@
manager.Add<ir::transform::AddEmptyEntryPoint>();
manager.Add<ir::transform::BlockDecoratedStructs>();
+ manager.Add<ir::transform::MergeReturn>();
manager.Add<ir::transform::VarForDynamicIndex>();
transform::DataMap outputs;