[spirv-reader] Compute basic block order
Test non-nested sequences and selections.
Bug: tint:3
Change-Id: Ibbbcd428d701d9e7d4da1682f94c2bdbef00121b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19920
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/BUILD.gn b/BUILD.gn
index 3d68847..c79492c 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -511,6 +511,7 @@
"src/reader/spirv/enum_converter_test.cc",
"src/reader/spirv/fail_stream_test.cc",
"src/reader/spirv/function_arithmetic_test.cc",
+ "src/reader/spirv/function_cfg_test.cc",
"src/reader/spirv/function_conversion_test.cc",
"src/reader/spirv/function_decl_test.cc",
"src/reader/spirv/function_logical_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 8a490be..1aa773f 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -326,6 +326,7 @@
reader/spirv/enum_converter_test.cc
reader/spirv/fail_stream_test.cc
reader/spirv/function_arithmetic_test.cc
+ reader/spirv/function_cfg_test.cc
reader/spirv/function_conversion_test.cc
reader/spirv/function_decl_test.cc
reader/spirv/function_logical_test.cc
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 9e1d8ef..97c3d33 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -14,7 +14,10 @@
#include "src/reader/spirv/function.h"
+#include <unordered_map>
+#include <unordered_set>
#include <utility>
+#include <vector>
#include "source/opt/basic_block.h"
#include "source/opt/function.h"
@@ -112,6 +115,97 @@
return ast::BinaryOp::kNone;
}
+// @returns the merge block ID for the given basic block, or 0 if there is none.
+uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
+ // Get the OpSelectionMerge or OpLoopMerge instruction, if any.
+ auto* inst = bb.GetMergeInst();
+ return inst == nullptr ? 0 : inst->GetSingleWordInOperand(0);
+}
+
+// @returns the continue target ID for the given basic block, or 0 if there
+// is none.
+uint32_t ContinueTargetFor(const spvtools::opt::BasicBlock& bb) {
+ // Get the OpLoopMerge instruction, if any.
+ auto* inst = bb.GetLoopMergeInst();
+ return inst == nullptr ? 0 : inst->GetSingleWordInOperand(1);
+}
+
+// A structured traverser produces the reverse structured post-order of the
+// CFG of a function. The blocks traversed are the transitive closure (minimum
+// fixed point) of:
+// - the entry block
+// - a block reached by a branch from another block in the set
+// - a block mentioned as a merge block or continue target for a block in the
+// set
+class StructuredTraverser {
+ public:
+ explicit StructuredTraverser(const spvtools::opt::Function& function)
+ : function_(function) {
+ for (auto& block : function_) {
+ id_to_block_[block.id()] = █
+ }
+ }
+
+ // Returns the reverse postorder traversal of the CFG, where:
+ // - a merge block always follows its associated constructs
+ // - a continue target always follows the associated loop construct, if any
+ // @returns the IDs of blocks in reverse structured post order
+ std::vector<uint32_t> ReverseStructuredPostOrder() {
+ visit_order_.clear();
+ visited_.clear();
+ VisitBackward(function_.entry()->id());
+
+ std::vector<uint32_t> order(visit_order_.rbegin(), visit_order_.rend());
+ return order;
+ }
+
+ private:
+ // Executes a depth first search of the CFG, where right after we visit a
+ // header, we will visit its merge block, then its continue target (if any).
+ // Also records the post order ordering.
+ void VisitBackward(uint32_t id) {
+ if (id == 0)
+ return;
+ if (visited_.count(id))
+ return;
+ visited_.insert(id);
+
+ const spvtools::opt::BasicBlock* bb =
+ id_to_block_[id]; // non-null for valid modules
+ VisitBackward(MergeFor(*bb));
+ VisitBackward(ContinueTargetFor(*bb));
+
+ // Visit successors. We will naturally skip the continue target and merge
+ // blocks.
+ auto* terminator = bb->terminator();
+ auto opcode = terminator->opcode();
+ if (opcode == SpvOpBranchConditional) {
+ // Visit the false branch, then the true branch, to make them come
+ // out in the natural order for an "if".
+ VisitBackward(terminator->GetSingleWordInOperand(2));
+ VisitBackward(terminator->GetSingleWordInOperand(1));
+ } else if (opcode == SpvOpBranch) {
+ VisitBackward(terminator->GetSingleWordInOperand(0));
+ } else if (opcode == SpvOpSwitch) {
+ // TODO(dneto): Consider visiting the labels in literal-value order.
+ std::vector<uint32_t> successors;
+ bb->ForEachSuccessorLabel([&successors](const uint32_t succ_id) {
+ successors.push_back(succ_id);
+ });
+ for (auto succ_id : successors) {
+ VisitBackward(succ_id);
+ }
+ }
+
+ visit_order_.push_back(id);
+ }
+
+ const spvtools::opt::Function& function_;
+ std::unordered_map<uint32_t, const spvtools::opt::BasicBlock*> id_to_block_;
+ std::vector<uint32_t> visit_order_;
+ std::unordered_set<uint32_t> visited_;
+};
+
} // namespace
FunctionEmitter::FunctionEmitter(ParserImpl* pi,
@@ -213,6 +307,8 @@
}
bool FunctionEmitter::EmitBody() {
+ ComputeBlockOrderAndPositions();
+
if (!EmitFunctionVariables()) {
return false;
}
@@ -222,6 +318,18 @@
return success();
}
+void FunctionEmitter::ComputeBlockOrderAndPositions() {
+ for (auto& block : function_) {
+ block_info_[block.id()] = std::make_unique<BlockInfo>(block);
+ }
+
+ rspo_ = StructuredTraverser(function_).ReverseStructuredPostOrder();
+
+ for (uint32_t i = 0; i < rspo_.size(); ++i) {
+ GetBlockInfo(rspo_[i])->pos = i;
+ }
+}
+
bool FunctionEmitter::EmitFunctionVariables() {
if (failed()) {
return false;
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 433b078..48367a3 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -36,6 +36,23 @@
namespace reader {
namespace spirv {
+/// Bookkeeping info for a basic block.
+struct BlockInfo {
+ /// Constructor
+ /// @param bb internal representation of the basic block
+ explicit BlockInfo(const spvtools::opt::BasicBlock& bb)
+ : basic_block(&bb), id(bb.id()) {}
+
+ /// The internal representation of the basic block.
+ const spvtools::opt::BasicBlock* basic_block;
+
+ /// The ID of the OpLabel instruction that starts this block.
+ uint32_t id = 0;
+
+ /// The position of this block in the reverse structured post-order.
+ uint32_t pos = 0;
+};
+
/// A FunctionEmitter emits a SPIR-V function onto a Tint AST module.
class FunctionEmitter {
public:
@@ -73,6 +90,14 @@
/// @returns false if emission failed.
bool EmitBody();
+ /// Determines the output order for the basic blocks in the function.
+ /// Populates |rspo_| and the |pos| block info member.
+ void ComputeBlockOrderAndPositions();
+
+ /// @returns the reverse structured post order of the basic blocks in
+ /// the function.
+ const std::vector<uint32_t>& rspo() const { return rspo_; }
+
/// Emits declarations of function variables.
/// @returns false if emission failed.
bool EmitFunctionVariables();
@@ -116,6 +141,16 @@
TypedExpression MaybeEmitCombinatorialValue(
const spvtools::opt::Instruction& inst);
+ /// Gets the block info for a block ID, if any exists
+ /// @param id the SPIR-V ID of the OpLabel instruction starting the block
+ /// @returns the block info for the given ID, if it exists, or nullptr
+ BlockInfo* GetBlockInfo(uint32_t id) {
+ auto where = block_info_.find(id);
+ if (where == block_info_.end())
+ return nullptr;
+ return where->second.get();
+ }
+
private:
/// @returns the store type for the OpVariable instruction, or
/// null on failure.
@@ -136,6 +171,13 @@
std::unordered_set<uint32_t> identifier_values_;
// Mapping from SPIR-V ID that is used at most once, to its AST expression.
std::unordered_map<uint32_t, TypedExpression> singly_used_values_;
+
+ // The IDs of basic blocks, in reverse structured post-order (RSPO).
+ // This is the output order for the basic blocks.
+ std::vector<uint32_t> rspo_;
+
+ // Mapping from block ID to its bookkeeping info.
+ std::unordered_map<uint32_t, std::unique_ptr<BlockInfo>> block_info_;
};
} // namespace spirv
diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc
new file mode 100644
index 0000000..d6b6d12
--- /dev/null
+++ b/src/reader/spirv/function_cfg_test.cc
@@ -0,0 +1,413 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "src/reader/spirv/function.h"
+#include "src/reader/spirv/parser_impl.h"
+#include "src/reader/spirv/parser_impl_test_helper.h"
+#include "src/reader/spirv/spirv_tools_helpers_test.h"
+
+namespace tint {
+namespace reader {
+namespace spirv {
+namespace {
+
+using ::testing::ElementsAre;
+
+std::string CommonTypes() {
+ return R"(
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+
+ %bool = OpTypeBool
+ %cond = OpUndef %bool
+
+ %uint = OpTypeInt 32 0
+ %selector = OpUndef %uint
+ )";
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_OneBlock) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %42 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(42));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_IgnoreStaticalyUnreachable) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpBranch %20
+
+ %15 = OpLabel ; statically dead
+ OpReturn
+
+ %20 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 20));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_ReorderSequence) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpBranch %20
+
+ %30 = OpLabel
+ OpReturn
+
+ %20 = OpLabel
+ OpBranch %30 ; backtrack
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 30));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_RespectConditionalBranchOrder) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpBranchConditional %cond %20 %30
+
+ %99 = OpLabel
+ OpReturn
+
+ %30 = OpLabel
+ OpReturn
+
+ %20 = OpLabel
+ OpBranch %99
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 30, 99));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_TrueOnlyBranch) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpBranchConditional %cond %20 %99
+
+ %99 = OpLabel
+ OpReturn
+
+ %20 = OpLabel
+ OpBranch %99
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 99));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_FalseOnlyBranch) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpBranchConditional %cond %99 %20
+
+ %99 = OpLabel
+ OpReturn
+
+ %20 = OpLabel
+ OpBranch %99
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 99));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_SwitchOrderNaturallyReversed) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpSwitch %selector %99 20 %20 30 %30
+
+ %99 = OpLabel
+ OpReturn
+
+ %30 = OpLabel
+ OpReturn
+
+ %20 = OpLabel
+ OpBranch %99
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 20, 99));
+}
+
+TEST_F(SpvParserTest,
+ ComputeBlockOrder_SwitchWithDefaultOrderNaturallyReversed) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpSwitch %selector %80 20 %20 30 %30
+
+ %80 = OpLabel ; the default case
+ OpBranch %99
+
+ %99 = OpLabel
+ OpReturn
+
+ %30 = OpLabel
+ OpReturn
+
+ %20 = OpLabel
+ OpBranch %99
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 20, 80, 99));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_RespectSwitchCaseFallthrough) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpSwitch %selector %99 20 %20 30 %30 40 %40 50 %50
+
+ %50 = OpLabel
+ OpBranch %99
+
+ %99 = OpLabel
+ OpReturn
+
+ %40 = OpLabel
+ OpBranch %99
+
+ %30 = OpLabel
+ OpBranch %50 ; fallthrough
+
+ %20 = OpLabel
+ OpBranch %40 ; fallthrough
+
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 50, 20, 40, 99)) << assembly;
+}
+
+TEST_F(SpvParserTest,
+ ComputeBlockOrder_RespectSwitchCaseFallthrough_FromDefault) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpSwitch %selector %80 20 %20 30 %30 40 %40
+
+ %80 = OpLabel ; the default case
+ OpBranch %30 ; fallthrough to another case
+
+ %99 = OpLabel
+ OpReturn
+
+ %40 = OpLabel
+ OpBranch %99
+
+ %30 = OpLabel
+ OpBranch %40
+
+ %20 = OpLabel
+ OpBranch %99
+
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 80, 30, 40, 99)) << assembly;
+}
+
+TEST_F(SpvParserTest,
+ ComputeBlockOrder_RespectSwitchCaseFallthrough_FromCaseToDefaultToCase) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpSwitch %selector %80 20 %20 30 %30
+
+ %99 = OpLabel
+ OpReturn
+
+ %20 = OpLabel
+ OpBranch %80 ; fallthrough to default
+
+ %80 = OpLabel ; the default case
+ OpBranch %30 ; fallthrough to 30
+
+ %30 = OpLabel
+ OpBranch %99
+
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 80, 30, 99)) << assembly;
+}
+
+TEST_F(SpvParserTest,
+ ComputeBlockOrder_SwitchCasesFallthrough_OppositeDirections) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpSwitch %selector %99 20 %20 30 %30 40 %40 50 %50
+
+ %99 = OpLabel
+ OpReturn
+
+ %20 = OpLabel
+ OpBranch %30 ; forward
+
+ %40 = OpLabel
+ OpBranch %99
+
+ %30 = OpLabel
+ OpBranch %99
+
+ ; SPIR-V doesn't actually allow a fall-through that goes backward in the
+ ; module. But the block ordering algorithm tolerates it.
+ %50 = OpLabel
+ OpBranch %40 ; backward
+
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 50, 40, 20, 30, 99)) << assembly;
+}
+
+TEST_F(SpvParserTest,
+ ComputeBlockOrder_RespectSwitchCaseFallthrough_Interleaved) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpSwitch %selector %99 20 %20 30 %30 40 %40 50 %50
+
+ %99 = OpLabel
+ OpReturn
+
+ %20 = OpLabel
+ OpBranch %40
+
+ %30 = OpLabel
+ OpBranch %50
+
+ %40 = OpLabel
+ OpBranch %60
+
+ %50 = OpLabel
+ OpBranch %70
+
+ %60 = OpLabel
+ OpBranch %99
+
+ %70 = OpLabel
+ OpBranch %99
+
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ fe.ComputeBlockOrderAndPositions();
+
+ EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 50, 70, 20, 40, 60, 99))
+ << assembly;
+}
+
+// TODO(dneto): test nesting
+// TODO(dneto): test loops
+
+} // namespace
+} // namespace spirv
+} // namespace reader
+} // namespace tint