[ir][msl] Add support for `if` statements

Adds support for emitting `if`, `if-else` and `if-elseif` statements
from the MSL IR generator.

Bug: tint:1967
Change-Id: I8c4ff5bfe5a9505ca1f1c7ced4ee71c1d6d5c108
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144265
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 371d351..5abb8d1 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -2329,6 +2329,7 @@
         "lang/msl/writer/printer/constant_test.cc",
         "lang/msl/writer/printer/function_test.cc",
         "lang/msl/writer/printer/helper_test.h",
+        "lang/msl/writer/printer/if_test.cc",
         "lang/msl/writer/printer/return_test.cc",
         "lang/msl/writer/printer/type_test.cc",
       ]
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index f1dbdaf..92b5621 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1556,6 +1556,7 @@
         lang/msl/writer/printer/constant_test.cc
         lang/msl/writer/printer/function_test.cc
         lang/msl/writer/printer/helper_test.h
+        lang/msl/writer/printer/if_test.cc
         lang/msl/writer/printer/return_test.cc
         lang/msl/writer/printer/type_test.cc
       )
diff --git a/src/tint/lang/msl/writer/printer/helper_test.h b/src/tint/lang/msl/writer/printer/helper_test.h
index d6c9f8b..152cf3b 100644
--- a/src/tint/lang/msl/writer/printer/helper_test.h
+++ b/src/tint/lang/msl/writer/printer/helper_test.h
@@ -18,7 +18,7 @@
 #include <iostream>
 #include <string>
 
-#include "gmock/gmock.h"
+#include "gtest/gtest.h"
 #include "src/tint/lang/core/ir/builder.h"
 #include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/lang/msl/writer/printer/printer.h"
diff --git a/src/tint/lang/msl/writer/printer/if_test.cc b/src/tint/lang/msl/writer/printer/if_test.cc
new file mode 100644
index 0000000..0f75aec
--- /dev/null
+++ b/src/tint/lang/msl/writer/printer/if_test.cc
@@ -0,0 +1,226 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/msl/writer/printer/helper_test.h"
+
+using namespace tint::number_suffixes;  // NOLINT
+
+namespace tint::msl::writer {
+namespace {
+
+TEST_F(MslPrinterTest, If) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* if_ = b.If(true);
+        b.Append(if_->True(), [&] { b.ExitIf(if_); });
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  if (true) {
+  }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfWithElseIf) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* if_ = b.If(true);
+        b.Append(if_->True(), [&] { b.ExitIf(if_); });
+        b.Append(if_->False(), [&] {
+            auto* false_ = b.If(false);
+            b.Append(false_->True(), [&] { b.ExitIf(false_); });
+            b.ExitIf(if_);
+        });
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  if (true) {
+  } else {
+    if (false) {
+    }
+  }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfWithElse) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* if_ = b.If(true);
+        b.Append(if_->True(), [&] { b.ExitIf(if_); });
+        b.Append(if_->False(), [&] { b.Return(func); });
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  if (true) {
+  } else {
+    return;
+  }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfBothBranchesReturn) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* if_ = b.If(true);
+        b.Append(if_->True(), [&] { b.Return(func); });
+        b.Append(if_->False(), [&] { b.Return(func); });
+        b.Unreachable();
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  if (true) {
+    return;
+  } else {
+    return;
+  }
+  /* unreachable */
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfWithSinglePhi) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* i = b.If(true);
+        i->SetResults(b.InstructionResult(ty.i32()));
+        b.Append(i->True(), [&] {  //
+            b.ExitIf(i, 10_i);
+        });
+        b.Append(i->False(), [&] {  //
+            b.ExitIf(i, 20_i);
+        });
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  int tint_symbol;
+  if (true) {
+    tint_symbol = 10;
+  } else {
+    tint_symbol = 20;
+  }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfWithMultiPhi) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* i = b.If(true);
+        i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+        b.Append(i->True(), [&] {  //
+            b.ExitIf(i, 10_i, true);
+        });
+        b.Append(i->False(), [&] {  //
+            b.ExitIf(i, 20_i, false);
+        });
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  int tint_symbol;
+  bool tint_symbol_1;
+  if (true) {
+    tint_symbol = 10;
+    tint_symbol_1 = true;
+  } else {
+    tint_symbol = 20;
+    tint_symbol_1 = false;
+  }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, DISABLED_IfWithMultiPhiReturn1) {
+    auto* func = b.Function("foo", ty.i32());
+    b.Append(func->Block(), [&] {
+        auto* i = b.If(true);
+        i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+        b.Append(i->True(), [&] {  //
+            b.ExitIf(i, 10_i, true);
+        });
+        b.Append(i->False(), [&] {  //
+            b.ExitIf(i, 20_i, false);
+        });
+        b.Return(func, i->Result(0));
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+int foo() {
+  int tint_symbol;
+  bool tint_symbol_1;
+  if (true) {
+    tint_symbol = 10;
+    tint_symbol_1 = true;
+  } else {
+    tint_symbol = 20;
+    tint_symbol_1 = true;
+  }
+  return tint_symbol;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, DISABLED_IfWithMultiPhiReturn2) {
+    auto* func = b.Function("foo", ty.bool_());
+    b.Append(func->Block(), [&] {
+        auto* i = b.If(true);
+        i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+        b.Append(i->True(), [&] {  //
+            b.ExitIf(i, 10_i, true);
+        });
+        b.Append(i->False(), [&] {  //
+            b.ExitIf(i, 20_i, false);
+        });
+        b.Return(func, i->Result(1));
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+bool foo() {
+  int tint_symbol;
+  bool tint_symbol_1;
+  if (true) {
+    tint_symbol = 10;
+    tint_symbol_1 = true;
+  } else {
+    tint_symbol = 20;
+    tint_symbol_1 = true;
+  }
+  return tint_symbol_1;
+}
+)");
+}
+
+}  // namespace
+}  // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 5294f89..dfb448c 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -17,8 +17,11 @@
 #include "src/tint/lang/core/constant/composite.h"
 #include "src/tint/lang/core/constant/splat.h"
 #include "src/tint/lang/core/ir/constant.h"
+#include "src/tint/lang/core/ir/exit_if.h"
+#include "src/tint/lang/core/ir/if.h"
 #include "src/tint/lang/core/ir/multi_in_block.h"
 #include "src/tint/lang/core/ir/return.h"
+#include "src/tint/lang/core/ir/unreachable.h"
 #include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/lang/core/type/array.h"
 #include "src/tint/lang/core/type/atomic.h"
@@ -149,10 +152,6 @@
 }
 
 void Printer::EmitBlock(ir::Block* block) {
-    if (block->As<ir::MultiInBlock>()) {
-        // TODO(dsinclair): Emit variables to used by the PHIs.
-    }
-
     // TODO(dsinclair): Handle inline things
     // MarkInlinable(block);
 
@@ -164,12 +163,76 @@
 
     for (auto* inst : *block) {
         Switch(
-            inst,                                   //
-            [&](ir::Return* r) { EmitReturn(r); },  //
+            inst,                                          //
+            [&](ir::ExitIf* e) { EmitExitIf(e); },         //
+            [&](ir::If* if_) { EmitIf(if_); },             //
+            [&](ir::Return* r) { EmitReturn(r); },         //
+            [&](ir::Unreachable*) { EmitUnreachable(); },  //
             [&](Default) { TINT_ICE() << "unimplemented instruction: " << inst->TypeInfo().name; });
     }
 }
 
+void Printer::EmitIf(ir::If* if_) {
+    // TODO(dsinclair): Detect if this is a short-circuit and rebuild || and &&.
+
+    // Emit any nodes that need to be used as PHI nodes
+    for (auto* phi : if_->Results()) {
+        if (!ir_->NameOf(phi).IsValid()) {
+            ir_->SetName(phi, ir_->symbols.New());
+        }
+
+        auto out = Line();
+        EmitType(out, phi->Type());
+        out << " " << ir_->NameOf(phi).Name() << ";";
+    }
+
+    {
+        auto out = Line();
+        out << "if (";
+
+        // TODO(dsinclair): This should emit the expression instead of just assuming it's a constant
+        if (!if_->Condition()->Is<ir::Constant>()) {
+            TINT_ICE() << "if only handles constants";
+            return;
+        }
+        EmitConstant(out, if_->Condition()->As<ir::Constant>());
+        out << ") {";
+    }
+
+    {
+        ScopedIndent si(current_buffer_);
+        EmitBlockInstructions(if_->True());
+    }
+
+    if (if_->False() && !if_->False()->IsEmpty()) {
+        Line() << "} else {";
+
+        ScopedIndent si(current_buffer_);
+        EmitBlockInstructions(if_->False());
+    }
+
+    Line() << "}";
+}
+
+void Printer::EmitExitIf(ir::ExitIf* e) {
+    auto results = e->If()->Results();
+    auto args = e->Args();
+    for (size_t i = 0; i < e->Args().Length(); ++i) {
+        auto* phi = results[i];
+        auto* val = args[i];
+
+        if (!val->Is<ir::Constant>()) {
+            TINT_ICE() << "exit-if only handles constants";
+            return;
+        }
+
+        auto out = Line();
+        out << ir_->NameOf(phi).Name() << " = "; /* << Expr(val); */
+        EmitConstant(out, val->As<ir::Constant>());
+        out << ";";
+    }
+}
+
 void Printer::EmitReturn(ir::Return* r) {
     // If this return has no arguments and the current block is for the function which is being
     // returned, skip the return.
@@ -194,6 +257,10 @@
     out << ";";
 }
 
+void Printer::EmitUnreachable() {
+    Line() << "/* unreachable */";
+}
+
 void Printer::EmitAddressSpace(StringStream& out, builtin::AddressSpace sc) {
     switch (sc) {
         case builtin::AddressSpace::kFunction:
diff --git a/src/tint/lang/msl/writer/printer/printer.h b/src/tint/lang/msl/writer/printer/printer.h
index 59c87ac..d740b19 100644
--- a/src/tint/lang/msl/writer/printer/printer.h
+++ b/src/tint/lang/msl/writer/printer/printer.h
@@ -27,7 +27,10 @@
 
 // Forward declarations
 namespace tint::ir {
+class ExitIf;
+class If;
 class Return;
+class Unreachable;
 }  // namespace tint::ir
 
 namespace tint::msl::writer {
@@ -57,9 +60,18 @@
     /// @param block the block with the instructions to emit
     void EmitBlockInstructions(ir::Block* block);
 
+    /// Emit an if instruction
+    /// @param if_ the if instruction
+    void EmitIf(ir::If* if_);
+    /// Emit an exit-if instruction
+    /// @param e the exit-if instruction
+    void EmitExitIf(ir::ExitIf* e);
+
     /// Emit a return instruction
     /// @param r the return instruction
     void EmitReturn(ir::Return* r);
+    /// Emit an unreachable instruction
+    void EmitUnreachable();
 
     /// Emit a type
     /// @param out the stream to emit too
diff --git a/src/tint/lang/msl/writer/printer/return_test.cc b/src/tint/lang/msl/writer/printer/return_test.cc
index ae8fdab..af48ef4 100644
--- a/src/tint/lang/msl/writer/printer/return_test.cc
+++ b/src/tint/lang/msl/writer/printer/return_test.cc
@@ -19,12 +19,12 @@
 namespace tint::msl::writer {
 namespace {
 
-// TODO(dsinclair): Requires if emission in MSL generator
-TEST_F(MslPrinterTest, DISABLED_Return) {
+TEST_F(MslPrinterTest, Return) {
     auto* func = b.Function("foo", ty.void_());
     b.Append(func->Block(), [&] {
         auto* if_ = b.If(true);
         b.Append(if_->True(), [&] { b.Return(func); });
+        b.Return(func);
     });
 
     ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();