spirv-writer: Fix termination of basic blocks
There are a few places where a branch or return is created,
conditionally on whether a terminator was the last thing seen.
The goal is to generate a SPIR-V basic block terminator exactly
when needed, and to avoid generating a branch or return immediately
after a prior terminator.
Previously, the decision was based on the last thing seen in the AST.
But we should instead check the emitted SPIR-V instead.
This fixes cases such as a break or return inside an else-if.
That's because an if/elseif is actually a selection inside a selection.
Looking at the AST only works when trying to terminate the *inside*
selection. In the outer recursive call, the last AST node is
no longer a terminator, and we would skip generating the branch
to the merge block.
Fixed: tint:1315
Change-Id: I6b886ce85d1d681f2063997e469e0c1b4e5973a2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/73480
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 04334a2..1f6246e 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -96,23 +96,6 @@
return !stmts->Empty() && stmts->Last()->Is<ast::FallthroughStatement>();
}
-// A terminator is anything which will cause a SPIR-V terminator to be emitted.
-// This means things like breaks, fallthroughs and continues which all emit an
-// OpBranch or return for the OpReturn emission.
-bool LastIsTerminator(const ast::BlockStatement* stmts) {
- if (IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
- ast::DiscardStatement, ast::ReturnStatement,
- ast::FallthroughStatement>(stmts->Last())) {
- return true;
- }
-
- if (auto* block = As<ast::BlockStatement>(stmts->Last())) {
- return LastIsTerminator(block);
- }
-
- return false;
-}
-
/// Returns the matrix type that is `type` or that is wrapped by
/// one or more levels of an arrays inside of `type`.
/// @param type the given type, which must not be null
@@ -650,7 +633,7 @@
}
}
- if (!LastIsTerminator(func_ast->body)) {
+ if (InsideBasicBlock()) {
if (func->ReturnType()->Is<sem::Void>()) {
push_function_inst(spv::Op::OpReturn, {});
} else {
@@ -3500,7 +3483,7 @@
return false;
}
// We only branch if the last element of the body didn't already branch.
- if (!LastIsTerminator(true_body)) {
+ if (InsideBasicBlock()) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(merge_block_id)})) {
return false;
@@ -3525,7 +3508,7 @@
return false;
}
}
- if (!LastIsTerminator(else_stmt->body)) {
+ if (InsideBasicBlock()) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(merge_block_id)})) {
return false;
@@ -3673,7 +3656,7 @@
{Operand::Int(case_ids[i + 1])})) {
return false;
}
- } else if (!LastIsTerminator(item->body)) {
+ } else if (InsideBasicBlock()) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(merge_block_id)})) {
return false;
@@ -3765,7 +3748,7 @@
}
// We only branch if the last element of the body didn't already branch.
- if (!LastIsTerminator(stmt->body)) {
+ if (InsideBasicBlock()) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(continue_block_id)})) {
return false;
@@ -4414,6 +4397,34 @@
return true;
}
+bool Builder::InsideBasicBlock() const {
+ if (functions_.empty()) {
+ return false;
+ }
+ const auto& instructions = functions_.back().instructions();
+ if (instructions.empty()) {
+ // The Function object does not explicitly represent its entry block
+ // label. So return *true* because an empty list means the only
+ // thing in the function is that entry block label.
+ return true;
+ }
+ const auto& inst = instructions.back();
+ switch (inst.opcode()) {
+ case spv::Op::OpBranch:
+ case spv::Op::OpBranchConditional:
+ case spv::Op::OpSwitch:
+ case spv::Op::OpReturn:
+ case spv::Op::OpReturnValue:
+ case spv::Op::OpUnreachable:
+ case spv::Op::OpKill:
+ case spv::Op::OpTerminateInvocation:
+ return false;
+ default:
+ break;
+ }
+ return true;
+}
+
Builder::ContinuingInfo::ContinuingInfo(
const ast::Statement* the_last_statement,
uint32_t loop_id,
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index d2b5237..fdfc3b3 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -216,6 +216,10 @@
functions_.back().push_var(operands);
}
+ /// @returns true if the current instruction insertion point is
+ /// inside a basic block.
+ bool InsideBasicBlock() const;
+
/// Converts a storage class to a SPIR-V storage class.
/// @param klass the storage class to convert
/// @returns the SPIR-V storage class or SpvStorageClassMax on error.
diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc
index 6273bee..2fda0e1 100644
--- a/src/writer/spirv/builder_if_test.cc
+++ b/src/writer/spirv/builder_if_test.cc
@@ -603,6 +603,91 @@
)");
}
+TEST_F(BuilderTest, If_ElseIf_WithReturn) {
+ // crbug.com/tint/1315
+ // if (false) {
+ // } else if (true) {
+ // return;
+ // }
+
+ auto* if_stmt = If(Expr(false), Block(),
+ ast::ElseStatementList{Else(Expr(true), Block(Return()))});
+ auto* fn = Func("f", {}, ty.void_(), {if_stmt});
+
+ spirv::Builder& b = Build();
+
+ EXPECT_TRUE(b.GenerateFunction(fn)) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%5 = OpTypeBool
+%6 = OpConstantFalse %5
+%10 = OpConstantTrue %5
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpSelectionMerge %7 None
+OpBranchConditional %6 %8 %9
+%8 = OpLabel
+OpBranch %7
+%9 = OpLabel
+OpSelectionMerge %11 None
+OpBranchConditional %10 %12 %11
+%12 = OpLabel
+OpReturn
+%11 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpReturn
+)");
+}
+
+TEST_F(BuilderTest, Loop_If_ElseIf_WithBreak) {
+ // crbug.com/tint/1315
+ // loop {
+ // if (false) {
+ // } else if (true) {
+ // break;
+ // }
+ // }
+
+ auto* if_stmt = If(Expr(false), Block(),
+ ast::ElseStatementList{Else(Expr(true), Block(Break()))});
+ auto* fn = Func("f", {}, ty.void_(), {Loop(Block(if_stmt))});
+
+ spirv::Builder& b = Build();
+
+ EXPECT_TRUE(b.GenerateFunction(fn)) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%9 = OpTypeBool
+%10 = OpConstantFalse %9
+%14 = OpConstantTrue %9
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpBranch %5
+%5 = OpLabel
+OpLoopMerge %6 %7 None
+OpBranch %8
+%8 = OpLabel
+OpSelectionMerge %11 None
+OpBranchConditional %10 %12 %13
+%12 = OpLabel
+OpBranch %11
+%13 = OpLabel
+OpSelectionMerge %15 None
+OpBranchConditional %14 %16 %15
+%16 = OpLabel
+OpBranch %6
+%15 = OpLabel
+OpBranch %11
+%11 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpBranch %5
+%6 = OpLabel
+OpReturn
+)");
+}
+
} // namespace
} // namespace spirv
} // namespace writer
diff --git a/src/writer/spirv/function.h b/src/writer/spirv/function.h
index 926f33f..df747ec 100644
--- a/src/writer/spirv/function.h
+++ b/src/writer/spirv/function.h
@@ -32,7 +32,7 @@
/// Constructor
/// @param declaration the function declaration
- /// @param label_op the operand for the initial function label
+ /// @param label_op the operand for function's entry block label
/// @param params the function parameters
Function(const Instruction& declaration,
const Operand& label_op,
@@ -49,7 +49,7 @@
/// @returns the declaration
const Instruction& declaration() const { return declaration_; }
- /// @returns the function label id
+ /// @returns the label ID for the function entry block
uint32_t label_id() const { return label_op_.to_i(); }
/// Adds an instruction to the instruction list