[ir][msl] Add support for `if` statements
Adds support for emitting `if`, `if-else` and `if-elseif` statements
from the MSL IR generator.
Bug: tint:1967
Change-Id: I8c4ff5bfe5a9505ca1f1c7ced4ee71c1d6d5c108
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144265
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 371d351..5abb8d1 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -2329,6 +2329,7 @@
"lang/msl/writer/printer/constant_test.cc",
"lang/msl/writer/printer/function_test.cc",
"lang/msl/writer/printer/helper_test.h",
+ "lang/msl/writer/printer/if_test.cc",
"lang/msl/writer/printer/return_test.cc",
"lang/msl/writer/printer/type_test.cc",
]
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index f1dbdaf..92b5621 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1556,6 +1556,7 @@
lang/msl/writer/printer/constant_test.cc
lang/msl/writer/printer/function_test.cc
lang/msl/writer/printer/helper_test.h
+ lang/msl/writer/printer/if_test.cc
lang/msl/writer/printer/return_test.cc
lang/msl/writer/printer/type_test.cc
)
diff --git a/src/tint/lang/msl/writer/printer/helper_test.h b/src/tint/lang/msl/writer/printer/helper_test.h
index d6c9f8b..152cf3b 100644
--- a/src/tint/lang/msl/writer/printer/helper_test.h
+++ b/src/tint/lang/msl/writer/printer/helper_test.h
@@ -18,7 +18,7 @@
#include <iostream>
#include <string>
-#include "gmock/gmock.h"
+#include "gtest/gtest.h"
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/msl/writer/printer/printer.h"
diff --git a/src/tint/lang/msl/writer/printer/if_test.cc b/src/tint/lang/msl/writer/printer/if_test.cc
new file mode 100644
index 0000000..0f75aec
--- /dev/null
+++ b/src/tint/lang/msl/writer/printer/if_test.cc
@@ -0,0 +1,226 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/msl/writer/printer/helper_test.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::msl::writer {
+namespace {
+
+TEST_F(MslPrinterTest, If) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* if_ = b.If(true);
+ b.Append(if_->True(), [&] { b.ExitIf(if_); });
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ if (true) {
+ }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfWithElseIf) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* if_ = b.If(true);
+ b.Append(if_->True(), [&] { b.ExitIf(if_); });
+ b.Append(if_->False(), [&] {
+ auto* false_ = b.If(false);
+ b.Append(false_->True(), [&] { b.ExitIf(false_); });
+ b.ExitIf(if_);
+ });
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ if (true) {
+ } else {
+ if (false) {
+ }
+ }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfWithElse) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* if_ = b.If(true);
+ b.Append(if_->True(), [&] { b.ExitIf(if_); });
+ b.Append(if_->False(), [&] { b.Return(func); });
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ if (true) {
+ } else {
+ return;
+ }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfBothBranchesReturn) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* if_ = b.If(true);
+ b.Append(if_->True(), [&] { b.Return(func); });
+ b.Append(if_->False(), [&] { b.Return(func); });
+ b.Unreachable();
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ if (true) {
+ return;
+ } else {
+ return;
+ }
+ /* unreachable */
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfWithSinglePhi) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* i = b.If(true);
+ i->SetResults(b.InstructionResult(ty.i32()));
+ b.Append(i->True(), [&] { //
+ b.ExitIf(i, 10_i);
+ });
+ b.Append(i->False(), [&] { //
+ b.ExitIf(i, 20_i);
+ });
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ int tint_symbol;
+ if (true) {
+ tint_symbol = 10;
+ } else {
+ tint_symbol = 20;
+ }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, IfWithMultiPhi) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* i = b.If(true);
+ i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+ b.Append(i->True(), [&] { //
+ b.ExitIf(i, 10_i, true);
+ });
+ b.Append(i->False(), [&] { //
+ b.ExitIf(i, 20_i, false);
+ });
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+ int tint_symbol;
+ bool tint_symbol_1;
+ if (true) {
+ tint_symbol = 10;
+ tint_symbol_1 = true;
+ } else {
+ tint_symbol = 20;
+ tint_symbol_1 = false;
+ }
+}
+)");
+}
+
+TEST_F(MslPrinterTest, DISABLED_IfWithMultiPhiReturn1) {
+ auto* func = b.Function("foo", ty.i32());
+ b.Append(func->Block(), [&] {
+ auto* i = b.If(true);
+ i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+ b.Append(i->True(), [&] { //
+ b.ExitIf(i, 10_i, true);
+ });
+ b.Append(i->False(), [&] { //
+ b.ExitIf(i, 20_i, false);
+ });
+ b.Return(func, i->Result(0));
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+int foo() {
+ int tint_symbol;
+ bool tint_symbol_1;
+ if (true) {
+ tint_symbol = 10;
+ tint_symbol_1 = true;
+ } else {
+ tint_symbol = 20;
+ tint_symbol_1 = true;
+ }
+ return tint_symbol;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, DISABLED_IfWithMultiPhiReturn2) {
+ auto* func = b.Function("foo", ty.bool_());
+ b.Append(func->Block(), [&] {
+ auto* i = b.If(true);
+ i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+ b.Append(i->True(), [&] { //
+ b.ExitIf(i, 10_i, true);
+ });
+ b.Append(i->False(), [&] { //
+ b.ExitIf(i, 20_i, false);
+ });
+ b.Return(func, i->Result(1));
+ });
+
+ ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+ EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+bool foo() {
+ int tint_symbol;
+ bool tint_symbol_1;
+ if (true) {
+ tint_symbol = 10;
+ tint_symbol_1 = true;
+ } else {
+ tint_symbol = 20;
+ tint_symbol_1 = true;
+ }
+ return tint_symbol_1;
+}
+)");
+}
+
+} // namespace
+} // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 5294f89..dfb448c 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -17,8 +17,11 @@
#include "src/tint/lang/core/constant/composite.h"
#include "src/tint/lang/core/constant/splat.h"
#include "src/tint/lang/core/ir/constant.h"
+#include "src/tint/lang/core/ir/exit_if.h"
+#include "src/tint/lang/core/ir/if.h"
#include "src/tint/lang/core/ir/multi_in_block.h"
#include "src/tint/lang/core/ir/return.h"
+#include "src/tint/lang/core/ir/unreachable.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/atomic.h"
@@ -149,10 +152,6 @@
}
void Printer::EmitBlock(ir::Block* block) {
- if (block->As<ir::MultiInBlock>()) {
- // TODO(dsinclair): Emit variables to used by the PHIs.
- }
-
// TODO(dsinclair): Handle inline things
// MarkInlinable(block);
@@ -164,12 +163,76 @@
for (auto* inst : *block) {
Switch(
- inst, //
- [&](ir::Return* r) { EmitReturn(r); }, //
+ inst, //
+ [&](ir::ExitIf* e) { EmitExitIf(e); }, //
+ [&](ir::If* if_) { EmitIf(if_); }, //
+ [&](ir::Return* r) { EmitReturn(r); }, //
+ [&](ir::Unreachable*) { EmitUnreachable(); }, //
[&](Default) { TINT_ICE() << "unimplemented instruction: " << inst->TypeInfo().name; });
}
}
+void Printer::EmitIf(ir::If* if_) {
+ // TODO(dsinclair): Detect if this is a short-circuit and rebuild || and &&.
+
+ // Emit any nodes that need to be used as PHI nodes
+ for (auto* phi : if_->Results()) {
+ if (!ir_->NameOf(phi).IsValid()) {
+ ir_->SetName(phi, ir_->symbols.New());
+ }
+
+ auto out = Line();
+ EmitType(out, phi->Type());
+ out << " " << ir_->NameOf(phi).Name() << ";";
+ }
+
+ {
+ auto out = Line();
+ out << "if (";
+
+ // TODO(dsinclair): This should emit the expression instead of just assuming it's a constant
+ if (!if_->Condition()->Is<ir::Constant>()) {
+ TINT_ICE() << "if only handles constants";
+ return;
+ }
+ EmitConstant(out, if_->Condition()->As<ir::Constant>());
+ out << ") {";
+ }
+
+ {
+ ScopedIndent si(current_buffer_);
+ EmitBlockInstructions(if_->True());
+ }
+
+ if (if_->False() && !if_->False()->IsEmpty()) {
+ Line() << "} else {";
+
+ ScopedIndent si(current_buffer_);
+ EmitBlockInstructions(if_->False());
+ }
+
+ Line() << "}";
+}
+
+void Printer::EmitExitIf(ir::ExitIf* e) {
+ auto results = e->If()->Results();
+ auto args = e->Args();
+ for (size_t i = 0; i < e->Args().Length(); ++i) {
+ auto* phi = results[i];
+ auto* val = args[i];
+
+ if (!val->Is<ir::Constant>()) {
+ TINT_ICE() << "exit-if only handles constants";
+ return;
+ }
+
+ auto out = Line();
+ out << ir_->NameOf(phi).Name() << " = "; /* << Expr(val); */
+ EmitConstant(out, val->As<ir::Constant>());
+ out << ";";
+ }
+}
+
void Printer::EmitReturn(ir::Return* r) {
// If this return has no arguments and the current block is for the function which is being
// returned, skip the return.
@@ -194,6 +257,10 @@
out << ";";
}
+void Printer::EmitUnreachable() {
+ Line() << "/* unreachable */";
+}
+
void Printer::EmitAddressSpace(StringStream& out, builtin::AddressSpace sc) {
switch (sc) {
case builtin::AddressSpace::kFunction:
diff --git a/src/tint/lang/msl/writer/printer/printer.h b/src/tint/lang/msl/writer/printer/printer.h
index 59c87ac..d740b19 100644
--- a/src/tint/lang/msl/writer/printer/printer.h
+++ b/src/tint/lang/msl/writer/printer/printer.h
@@ -27,7 +27,10 @@
// Forward declarations
namespace tint::ir {
+class ExitIf;
+class If;
class Return;
+class Unreachable;
} // namespace tint::ir
namespace tint::msl::writer {
@@ -57,9 +60,18 @@
/// @param block the block with the instructions to emit
void EmitBlockInstructions(ir::Block* block);
+ /// Emit an if instruction
+ /// @param if_ the if instruction
+ void EmitIf(ir::If* if_);
+ /// Emit an exit-if instruction
+ /// @param e the exit-if instruction
+ void EmitExitIf(ir::ExitIf* e);
+
/// Emit a return instruction
/// @param r the return instruction
void EmitReturn(ir::Return* r);
+ /// Emit an unreachable instruction
+ void EmitUnreachable();
/// Emit a type
/// @param out the stream to emit too
diff --git a/src/tint/lang/msl/writer/printer/return_test.cc b/src/tint/lang/msl/writer/printer/return_test.cc
index ae8fdab..af48ef4 100644
--- a/src/tint/lang/msl/writer/printer/return_test.cc
+++ b/src/tint/lang/msl/writer/printer/return_test.cc
@@ -19,12 +19,12 @@
namespace tint::msl::writer {
namespace {
-// TODO(dsinclair): Requires if emission in MSL generator
-TEST_F(MslPrinterTest, DISABLED_Return) {
+TEST_F(MslPrinterTest, Return) {
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
auto* if_ = b.If(true);
b.Append(if_->True(), [&] { b.Return(func); });
+ b.Return(func);
});
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();