spirv-writer: Fix termination of basic blocks

There are a few places where a branch or return is created,
conditionally on whether a terminator was the last thing seen.
The goal is to generate a SPIR-V basic block terminator exactly
when needed, and to avoid generating a branch or return immediately
after a prior terminator.

Previously, the decision was based on the last thing seen in the AST.
But we should instead check the emitted SPIR-V instead.

This fixes cases such as a break or return inside an else-if.
That's because an if/elseif is actually a selection inside a selection.
Looking at the AST only works when trying to terminate the *inside*
selection.  In the outer recursive call, the last AST node is
no longer a terminator, and we would skip generating the branch
to the merge block.

Fixed: tint:1315
Change-Id: I6b886ce85d1d681f2063997e469e0c1b4e5973a2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/73480
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 04334a2..1f6246e 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -96,23 +96,6 @@
   return !stmts->Empty() && stmts->Last()->Is<ast::FallthroughStatement>();
 }
 
-// A terminator is anything which will cause a SPIR-V terminator to be emitted.
-// This means things like breaks, fallthroughs and continues which all emit an
-// OpBranch or return for the OpReturn emission.
-bool LastIsTerminator(const ast::BlockStatement* stmts) {
-  if (IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
-              ast::DiscardStatement, ast::ReturnStatement,
-              ast::FallthroughStatement>(stmts->Last())) {
-    return true;
-  }
-
-  if (auto* block = As<ast::BlockStatement>(stmts->Last())) {
-    return LastIsTerminator(block);
-  }
-
-  return false;
-}
-
 /// Returns the matrix type that is `type` or that is wrapped by
 /// one or more levels of an arrays inside of `type`.
 /// @param type the given type, which must not be null
@@ -650,7 +633,7 @@
     }
   }
 
-  if (!LastIsTerminator(func_ast->body)) {
+  if (InsideBasicBlock()) {
     if (func->ReturnType()->Is<sem::Void>()) {
       push_function_inst(spv::Op::OpReturn, {});
     } else {
@@ -3500,7 +3483,7 @@
     return false;
   }
   // We only branch if the last element of the body didn't already branch.
-  if (!LastIsTerminator(true_body)) {
+  if (InsideBasicBlock()) {
     if (!push_function_inst(spv::Op::OpBranch,
                             {Operand::Int(merge_block_id)})) {
       return false;
@@ -3525,7 +3508,7 @@
         return false;
       }
     }
-    if (!LastIsTerminator(else_stmt->body)) {
+    if (InsideBasicBlock()) {
       if (!push_function_inst(spv::Op::OpBranch,
                               {Operand::Int(merge_block_id)})) {
         return false;
@@ -3673,7 +3656,7 @@
                               {Operand::Int(case_ids[i + 1])})) {
         return false;
       }
-    } else if (!LastIsTerminator(item->body)) {
+    } else if (InsideBasicBlock()) {
       if (!push_function_inst(spv::Op::OpBranch,
                               {Operand::Int(merge_block_id)})) {
         return false;
@@ -3765,7 +3748,7 @@
     }
 
     // We only branch if the last element of the body didn't already branch.
-    if (!LastIsTerminator(stmt->body)) {
+    if (InsideBasicBlock()) {
       if (!push_function_inst(spv::Op::OpBranch,
                               {Operand::Int(continue_block_id)})) {
         return false;
@@ -4414,6 +4397,34 @@
   return true;
 }
 
+bool Builder::InsideBasicBlock() const {
+  if (functions_.empty()) {
+    return false;
+  }
+  const auto& instructions = functions_.back().instructions();
+  if (instructions.empty()) {
+    // The Function object does not explicitly represent its entry block
+    // label.  So return *true* because an empty list means the only
+    // thing in the function is that entry block label.
+    return true;
+  }
+  const auto& inst = instructions.back();
+  switch (inst.opcode()) {
+    case spv::Op::OpBranch:
+    case spv::Op::OpBranchConditional:
+    case spv::Op::OpSwitch:
+    case spv::Op::OpReturn:
+    case spv::Op::OpReturnValue:
+    case spv::Op::OpUnreachable:
+    case spv::Op::OpKill:
+    case spv::Op::OpTerminateInvocation:
+      return false;
+    default:
+      break;
+  }
+  return true;
+}
+
 Builder::ContinuingInfo::ContinuingInfo(
     const ast::Statement* the_last_statement,
     uint32_t loop_id,
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index d2b5237..fdfc3b3 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -216,6 +216,10 @@
     functions_.back().push_var(operands);
   }
 
+  /// @returns true if the current instruction insertion point is
+  /// inside a basic block.
+  bool InsideBasicBlock() const;
+
   /// Converts a storage class to a SPIR-V storage class.
   /// @param klass the storage class to convert
   /// @returns the SPIR-V storage class or SpvStorageClassMax on error.
diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc
index 6273bee..2fda0e1 100644
--- a/src/writer/spirv/builder_if_test.cc
+++ b/src/writer/spirv/builder_if_test.cc
@@ -603,6 +603,91 @@
 )");
 }
 
+TEST_F(BuilderTest, If_ElseIf_WithReturn) {
+  // crbug.com/tint/1315
+  // if (false) {
+  // } else if (true) {
+  //   return;
+  // }
+
+  auto* if_stmt = If(Expr(false), Block(),
+                     ast::ElseStatementList{Else(Expr(true), Block(Return()))});
+  auto* fn = Func("f", {}, ty.void_(), {if_stmt});
+
+  spirv::Builder& b = Build();
+
+  EXPECT_TRUE(b.GenerateFunction(fn)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%5 = OpTypeBool
+%6 = OpConstantFalse %5
+%10 = OpConstantTrue %5
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpSelectionMerge %7 None
+OpBranchConditional %6 %8 %9
+%8 = OpLabel
+OpBranch %7
+%9 = OpLabel
+OpSelectionMerge %11 None
+OpBranchConditional %10 %12 %11
+%12 = OpLabel
+OpReturn
+%11 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpReturn
+)");
+}
+
+TEST_F(BuilderTest, Loop_If_ElseIf_WithBreak) {
+  // crbug.com/tint/1315
+  // loop {
+  //   if (false) {
+  //   } else if (true) {
+  //     break;
+  //   }
+  // }
+
+  auto* if_stmt = If(Expr(false), Block(),
+                     ast::ElseStatementList{Else(Expr(true), Block(Break()))});
+  auto* fn = Func("f", {}, ty.void_(), {Loop(Block(if_stmt))});
+
+  spirv::Builder& b = Build();
+
+  EXPECT_TRUE(b.GenerateFunction(fn)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%9 = OpTypeBool
+%10 = OpConstantFalse %9
+%14 = OpConstantTrue %9
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpBranch %5
+%5 = OpLabel
+OpLoopMerge %6 %7 None
+OpBranch %8
+%8 = OpLabel
+OpSelectionMerge %11 None
+OpBranchConditional %10 %12 %13
+%12 = OpLabel
+OpBranch %11
+%13 = OpLabel
+OpSelectionMerge %15 None
+OpBranchConditional %14 %16 %15
+%16 = OpLabel
+OpBranch %6
+%15 = OpLabel
+OpBranch %11
+%11 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpBranch %5
+%6 = OpLabel
+OpReturn
+)");
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer
diff --git a/src/writer/spirv/function.h b/src/writer/spirv/function.h
index 926f33f..df747ec 100644
--- a/src/writer/spirv/function.h
+++ b/src/writer/spirv/function.h
@@ -32,7 +32,7 @@
 
   /// Constructor
   /// @param declaration the function declaration
-  /// @param label_op the operand for the initial function label
+  /// @param label_op the operand for function's entry block label
   /// @param params the function parameters
   Function(const Instruction& declaration,
            const Operand& label_op,
@@ -49,7 +49,7 @@
   /// @returns the declaration
   const Instruction& declaration() const { return declaration_; }
 
-  /// @returns the function label id
+  /// @returns the label ID for the function entry block
   uint32_t label_id() const { return label_op_.to_i(); }
 
   /// Adds an instruction to the instruction list