[spirv-writer] Avoid branch after dead if/then if/else

Bug: tint:64
Change-Id: I008c449ca634c6410055a65927199fda2d7bbb06
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20720
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 78d302c..2a6ee58 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -94,6 +94,8 @@
   }
 
   auto* last = stmts.back().get();
+  // TODO(dneto): Conditional break and conditional continue should return
+  // false.
   return last->IsBreak() || last->IsContinue() || last->IsReturn() ||
          last->IsKill() || last->IsFallthrough();
 }
@@ -1027,9 +1029,10 @@
   if (!GenerateStatementList(true_body)) {
     return false;
   }
-  // TODO(dsinclair): The branch should be optional based on how the
-  // StatementList ended ...
-  push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
+  // We only branch if the last element of the body didn't already branch.
+  if (!LastIsTerminator(true_body)) {
+    push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
+  }
 
   // Start the false block if needed
   if (false_block_id != merge_block_id) {
@@ -1041,14 +1044,13 @@
       if (!GenerateStatementList(else_stmt->body())) {
         return false;
       }
-      // TODO(dsinclair): The branch should be optional based on how the
-      // StatementList ended ...
-      push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
     } else {
       if (!GenerateConditionalBlock(else_stmt->condition(), else_stmt->body(),
                                     cur_else_idx + 1, else_stmts)) {
         return false;
       }
+    }
+    if (!LastIsTerminator(else_stmt->body())) {
       push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
     }
   }
diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc
index bd9ada5..dbe1dc4 100644
--- a/src/writer/spirv/builder_if_test.cc
+++ b/src/writer/spirv/builder_if_test.cc
@@ -17,11 +17,16 @@
 #include "gtest/gtest.h"
 #include "src/ast/assignment_statement.h"
 #include "src/ast/bool_literal.h"
+#include "src/ast/break_statement.h"
+#include "src/ast/continue_statement.h"
 #include "src/ast/else_statement.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/if_statement.h"
 #include "src/ast/int_literal.h"
+#include "src/ast/loop_statement.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/statement_condition.h"
 #include "src/ast/type/bool_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/context.h"
@@ -367,22 +372,313 @@
 )");
 }
 
-TEST_F(BuilderTest, DISABLED_If_WithBreak) {
-  // if (a) {
-  //   break;
+TEST_F(BuilderTest, If_WithBreak) {
+  ast::type::BoolType bool_type;
+  // loop {
+  //   if (true) {
+  //     break;
+  //   }
   // }
+  auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+  ast::StatementList if_body;
+  if_body.push_back(std::make_unique<ast::BreakStatement>());
+
+  auto if_stmt =
+      std::make_unique<ast::IfStatement>(std::move(cond), std::move(if_body));
+
+  ast::StatementList loop_body;
+  loop_body.push_back(std::move(if_stmt));
+
+  ast::LoopStatement expr(std::move(loop_body), {});
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+
+  EXPECT_TRUE(b.GenerateLoopStatement(&expr)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
+%6 = OpConstantTrue %5
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpBranch %1
+%1 = OpLabel
+OpLoopMerge %2 %3 None
+OpBranch %4
+%4 = OpLabel
+OpSelectionMerge %7 None
+OpBranchConditional %6 %8 %7
+%8 = OpLabel
+OpBranch %2
+%7 = OpLabel
+OpBranch %3
+%3 = OpLabel
+OpBranch %1
+%2 = OpLabel
+)");
 }
 
-TEST_F(BuilderTest, DISABLED_If_WithContinue) {
-  // if (a) {
-  //   continue;
+TEST_F(BuilderTest, If_WithElseBreak) {
+  ast::type::BoolType bool_type;
+  // loop {
+  //   if (true) {
+  //   } else {
+  //     break;
+  //   }
   // }
+  auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+  ast::StatementList if_body;
+  ast::StatementList else_body;
+  else_body.push_back(std::make_unique<ast::BreakStatement>());
+
+  ast::ElseStatementList else_stmts;
+  else_stmts.push_back(
+      std::make_unique<ast::ElseStatement>(std::move(else_body)));
+
+  auto if_stmt =
+      std::make_unique<ast::IfStatement>(std::move(cond), std::move(if_body));
+  if_stmt->set_else_statements(std::move(else_stmts));
+
+  ast::StatementList loop_body;
+  loop_body.push_back(std::move(if_stmt));
+
+  ast::LoopStatement expr(std::move(loop_body), {});
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+
+  EXPECT_TRUE(b.GenerateLoopStatement(&expr)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
+%6 = OpConstantTrue %5
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpBranch %1
+%1 = OpLabel
+OpLoopMerge %2 %3 None
+OpBranch %4
+%4 = OpLabel
+OpSelectionMerge %7 None
+OpBranchConditional %6 %8 %9
+%8 = OpLabel
+OpBranch %7
+%9 = OpLabel
+OpBranch %2
+%7 = OpLabel
+OpBranch %3
+%3 = OpLabel
+OpBranch %1
+%2 = OpLabel
+)");
 }
 
-TEST_F(BuilderTest, DISABLED_IF_WithReturn) {
-  // if (a) {
+// This is blocked on implementing conditional break
+TEST_F(BuilderTest, DISABLED_If_WithConditionalBreak) {
+  ast::type::BoolType bool_type;
+  // loop {
+  //   if (true) {
+  //     break if (false);
+  //   }
+  // }
+  auto cond_true = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+  auto cond_false = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, false));
+
+  ast::StatementList if_body;
+  if_body.push_back(std::make_unique<ast::BreakStatement>(
+      ast::StatementCondition::kIf, std::move(cond_false)));
+
+  auto if_stmt = std::make_unique<ast::IfStatement>(std::move(cond_true),
+                                                    std::move(if_body));
+
+  ast::StatementList loop_body;
+  loop_body.push_back(std::move(if_stmt));
+
+  ast::LoopStatement expr(std::move(loop_body), {});
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+
+  EXPECT_TRUE(b.GenerateLoopStatement(&expr)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
+%6 = OpConstantTrue %5
+%7 = OpConstantFalse %5
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpBranch %1
+%1 = OpLabel
+OpLoopMerge %2 %3 None
+OpBranch %4
+%4 = OpLabel
+OpSelectionMerge %8 None
+OpBranchConditional %6 %9 %8
+%9 = OpLabel
+OpBranchConditional %7 %2 %8
+%8 = OpLabel
+OpBranch %3
+%3 = OpLabel
+OpBranch %1
+%2 = OpLabel
+)");
+}
+
+TEST_F(BuilderTest, DISABLED_If_WithElseConditionalBreak) {
+  FAIL();
+}
+
+TEST_F(BuilderTest, If_WithContinue) {
+  ast::type::BoolType bool_type;
+  // loop {
+  //   if (true) {
+  //     continue;
+  //   }
+  // }
+  auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+  ast::StatementList if_body;
+  if_body.push_back(std::make_unique<ast::ContinueStatement>());
+
+  auto if_stmt =
+      std::make_unique<ast::IfStatement>(std::move(cond), std::move(if_body));
+
+  ast::StatementList loop_body;
+  loop_body.push_back(std::move(if_stmt));
+
+  ast::LoopStatement expr(std::move(loop_body), {});
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+
+  EXPECT_TRUE(b.GenerateLoopStatement(&expr)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
+%6 = OpConstantTrue %5
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpBranch %1
+%1 = OpLabel
+OpLoopMerge %2 %3 None
+OpBranch %4
+%4 = OpLabel
+OpSelectionMerge %7 None
+OpBranchConditional %6 %8 %7
+%8 = OpLabel
+OpBranch %3
+%7 = OpLabel
+OpBranch %3
+%3 = OpLabel
+OpBranch %1
+%2 = OpLabel
+)");
+}
+
+TEST_F(BuilderTest, DISABLED_If_WithConditionalContinue) {
+  FAIL();
+}
+TEST_F(BuilderTest, DISABLED_If_WithElseContinue) {
+  FAIL();
+}
+TEST_F(BuilderTest, DISABLED_If_WithElseConditionalContinue) {
+  FAIL();
+}
+
+TEST_F(BuilderTest, If_WithReturn) {
+  ast::type::BoolType bool_type;
+  // if (true) {
   //   return;
   // }
+  auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+  ast::StatementList if_body;
+  if_body.push_back(std::make_unique<ast::ReturnStatement>());
+
+  ast::IfStatement expr(std::move(cond), std::move(if_body));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+
+  EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
+%2 = OpConstantTrue %1
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpSelectionMerge %3 None
+OpBranchConditional %2 %4 %3
+%4 = OpLabel
+OpReturn
+%3 = OpLabel
+)");
+}
+
+TEST_F(BuilderTest, If_WithReturnValue) {
+  ast::type::BoolType bool_type;
+  // if (true) {
+  //   return false;
+  // }
+  auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+  auto cond2 = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, false));
+
+  ast::StatementList if_body;
+  if_body.push_back(std::make_unique<ast::ReturnStatement>(std::move(cond2)));
+
+  ast::IfStatement expr(std::move(cond), std::move(if_body));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+
+  EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
+%2 = OpConstantTrue %1
+%5 = OpConstantFalse %1
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpSelectionMerge %3 None
+OpBranchConditional %2 %4 %3
+%4 = OpLabel
+OpReturnValue %5
+%3 = OpLabel
+)");
 }
 
 }  // namespace