[spirv-reader] Compute basic block order

Test non-nested sequences and selections.

Bug: tint:3
Change-Id: Ibbbcd428d701d9e7d4da1682f94c2bdbef00121b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19920
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/BUILD.gn b/BUILD.gn
index 3d68847..c79492c 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -511,6 +511,7 @@
     "src/reader/spirv/enum_converter_test.cc",
     "src/reader/spirv/fail_stream_test.cc",
     "src/reader/spirv/function_arithmetic_test.cc",
+    "src/reader/spirv/function_cfg_test.cc",
     "src/reader/spirv/function_conversion_test.cc",
     "src/reader/spirv/function_decl_test.cc",
     "src/reader/spirv/function_logical_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 8a490be..1aa773f 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -326,6 +326,7 @@
     reader/spirv/enum_converter_test.cc
     reader/spirv/fail_stream_test.cc
     reader/spirv/function_arithmetic_test.cc
+    reader/spirv/function_cfg_test.cc
     reader/spirv/function_conversion_test.cc
     reader/spirv/function_decl_test.cc
     reader/spirv/function_logical_test.cc
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 9e1d8ef..97c3d33 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -14,7 +14,10 @@
 
 #include "src/reader/spirv/function.h"
 
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
+#include <vector>
 
 #include "source/opt/basic_block.h"
 #include "source/opt/function.h"
@@ -112,6 +115,97 @@
   return ast::BinaryOp::kNone;
 }
 
+// @returns the merge block ID for the given basic block, or 0 if there is none.
+uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
+  // Get the OpSelectionMerge or OpLoopMerge instruction, if any.
+  auto* inst = bb.GetMergeInst();
+  return inst == nullptr ? 0 : inst->GetSingleWordInOperand(0);
+}
+
+// @returns the continue target ID for the given basic block, or 0 if there
+// is none.
+uint32_t ContinueTargetFor(const spvtools::opt::BasicBlock& bb) {
+  // Get the OpLoopMerge instruction, if any.
+  auto* inst = bb.GetLoopMergeInst();
+  return inst == nullptr ? 0 : inst->GetSingleWordInOperand(1);
+}
+
+// A structured traverser produces the reverse structured post-order of the
+// CFG of a function.  The blocks traversed are the transitive closure (minimum
+// fixed point) of:
+//  - the entry block
+//  - a block reached by a branch from another block in the set
+//  - a block mentioned as a merge block or continue target for a block in the
+//  set
+class StructuredTraverser {
+ public:
+  explicit StructuredTraverser(const spvtools::opt::Function& function)
+      : function_(function) {
+    for (auto& block : function_) {
+      id_to_block_[block.id()] = &block;
+    }
+  }
+
+  // Returns the reverse postorder traversal of the CFG, where:
+  //  - a merge block always follows its associated constructs
+  //  - a continue target always follows the associated loop construct, if any
+  // @returns the IDs of blocks in reverse structured post order
+  std::vector<uint32_t> ReverseStructuredPostOrder() {
+    visit_order_.clear();
+    visited_.clear();
+    VisitBackward(function_.entry()->id());
+
+    std::vector<uint32_t> order(visit_order_.rbegin(), visit_order_.rend());
+    return order;
+  }
+
+ private:
+  // Executes a depth first search of the CFG, where right after we visit a
+  // header, we will visit its merge block, then its continue target (if any).
+  // Also records the post order ordering.
+  void VisitBackward(uint32_t id) {
+    if (id == 0)
+      return;
+    if (visited_.count(id))
+      return;
+    visited_.insert(id);
+
+    const spvtools::opt::BasicBlock* bb =
+        id_to_block_[id];  // non-null for valid modules
+    VisitBackward(MergeFor(*bb));
+    VisitBackward(ContinueTargetFor(*bb));
+
+    // Visit successors. We will naturally skip the continue target and merge
+    // blocks.
+    auto* terminator = bb->terminator();
+    auto opcode = terminator->opcode();
+    if (opcode == SpvOpBranchConditional) {
+      // Visit the false branch, then the true branch, to make them come
+      // out in the natural order for an "if".
+      VisitBackward(terminator->GetSingleWordInOperand(2));
+      VisitBackward(terminator->GetSingleWordInOperand(1));
+    } else if (opcode == SpvOpBranch) {
+      VisitBackward(terminator->GetSingleWordInOperand(0));
+    } else if (opcode == SpvOpSwitch) {
+      // TODO(dneto): Consider visiting the labels in literal-value order.
+      std::vector<uint32_t> successors;
+      bb->ForEachSuccessorLabel([&successors](const uint32_t succ_id) {
+        successors.push_back(succ_id);
+      });
+      for (auto succ_id : successors) {
+        VisitBackward(succ_id);
+      }
+    }
+
+    visit_order_.push_back(id);
+  }
+
+  const spvtools::opt::Function& function_;
+  std::unordered_map<uint32_t, const spvtools::opt::BasicBlock*> id_to_block_;
+  std::vector<uint32_t> visit_order_;
+  std::unordered_set<uint32_t> visited_;
+};
+
 }  // namespace
 
 FunctionEmitter::FunctionEmitter(ParserImpl* pi,
@@ -213,6 +307,8 @@
 }
 
 bool FunctionEmitter::EmitBody() {
+  ComputeBlockOrderAndPositions();
+
   if (!EmitFunctionVariables()) {
     return false;
   }
@@ -222,6 +318,18 @@
   return success();
 }
 
+void FunctionEmitter::ComputeBlockOrderAndPositions() {
+  for (auto& block : function_) {
+    block_info_[block.id()] = std::make_unique<BlockInfo>(block);
+  }
+
+  rspo_ = StructuredTraverser(function_).ReverseStructuredPostOrder();
+
+  for (uint32_t i = 0; i < rspo_.size(); ++i) {
+    GetBlockInfo(rspo_[i])->pos = i;
+  }
+}
+
 bool FunctionEmitter::EmitFunctionVariables() {
   if (failed()) {
     return false;
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 433b078..48367a3 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -36,6 +36,23 @@
 namespace reader {
 namespace spirv {
 
+/// Bookkeeping info for a basic block.
+struct BlockInfo {
+  /// Constructor
+  /// @param bb internal representation of the basic block
+  explicit BlockInfo(const spvtools::opt::BasicBlock& bb)
+      : basic_block(&bb), id(bb.id()) {}
+
+  /// The internal representation of the basic block.
+  const spvtools::opt::BasicBlock* basic_block;
+
+  /// The ID of the OpLabel instruction that starts this block.
+  uint32_t id = 0;
+
+  /// The position of this block in the reverse structured post-order.
+  uint32_t pos = 0;
+};
+
 /// A FunctionEmitter emits a SPIR-V function onto a Tint AST module.
 class FunctionEmitter {
  public:
@@ -73,6 +90,14 @@
   /// @returns false if emission failed.
   bool EmitBody();
 
+  /// Determines the output order for the basic blocks in the function.
+  /// Populates |rspo_| and the |pos| block info member.
+  void ComputeBlockOrderAndPositions();
+
+  /// @returns the reverse structured post order of the basic blocks in
+  /// the function.
+  const std::vector<uint32_t>& rspo() const { return rspo_; }
+
   /// Emits declarations of function variables.
   /// @returns false if emission failed.
   bool EmitFunctionVariables();
@@ -116,6 +141,16 @@
   TypedExpression MaybeEmitCombinatorialValue(
       const spvtools::opt::Instruction& inst);
 
+  /// Gets the block info for a block ID, if any exists
+  /// @param id the SPIR-V ID of the OpLabel instruction starting the block
+  /// @returns the block info for the given ID, if it exists, or nullptr
+  BlockInfo* GetBlockInfo(uint32_t id) {
+    auto where = block_info_.find(id);
+    if (where == block_info_.end())
+      return nullptr;
+    return where->second.get();
+  }
+
  private:
   /// @returns the store type for the OpVariable instruction, or
   /// null on failure.
@@ -136,6 +171,13 @@
   std::unordered_set<uint32_t> identifier_values_;
   // Mapping from SPIR-V ID that is used at most once, to its AST expression.
   std::unordered_map<uint32_t, TypedExpression> singly_used_values_;
+
+  // The IDs of basic blocks, in reverse structured post-order (RSPO).
+  // This is the output order for the basic blocks.
+  std::vector<uint32_t> rspo_;
+
+  // Mapping from block ID to its bookkeeping info.
+  std::unordered_map<uint32_t, std::unique_ptr<BlockInfo>> block_info_;
 };
 
 }  // namespace spirv
diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc
new file mode 100644
index 0000000..d6b6d12
--- /dev/null
+++ b/src/reader/spirv/function_cfg_test.cc
@@ -0,0 +1,413 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "src/reader/spirv/function.h"
+#include "src/reader/spirv/parser_impl.h"
+#include "src/reader/spirv/parser_impl_test_helper.h"
+#include "src/reader/spirv/spirv_tools_helpers_test.h"
+
+namespace tint {
+namespace reader {
+namespace spirv {
+namespace {
+
+using ::testing::ElementsAre;
+
+std::string CommonTypes() {
+  return R"(
+    %void = OpTypeVoid
+    %voidfn = OpTypeFunction %void
+
+    %bool = OpTypeBool
+    %cond = OpUndef %bool
+
+    %uint = OpTypeInt 32 0
+    %selector = OpUndef %uint
+  )";
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_OneBlock) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %42 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(42));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_IgnoreStaticalyUnreachable) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %15 = OpLabel ; statically dead
+     OpReturn
+
+     %20 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 20));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_ReorderSequence) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %30 = OpLabel
+     OpReturn
+
+     %20 = OpLabel
+     OpBranch %30 ; backtrack
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 30));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_RespectConditionalBranchOrder) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpBranchConditional %cond %20 %30
+
+     %99 = OpLabel
+     OpReturn
+
+     %30 = OpLabel
+     OpReturn
+
+     %20 = OpLabel
+     OpBranch %99
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 30, 99));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_TrueOnlyBranch) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpBranchConditional %cond %20 %99
+
+     %99 = OpLabel
+     OpReturn
+
+     %20 = OpLabel
+     OpBranch %99
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 99));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_FalseOnlyBranch) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpBranchConditional %cond %99 %20
+
+     %99 = OpLabel
+     OpReturn
+
+     %20 = OpLabel
+     OpBranch %99
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 99));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_SwitchOrderNaturallyReversed) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpSwitch %selector %99 20 %20 30 %30
+
+     %99 = OpLabel
+     OpReturn
+
+     %30 = OpLabel
+     OpReturn
+
+     %20 = OpLabel
+     OpBranch %99
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 20, 99));
+}
+
+TEST_F(SpvParserTest,
+       ComputeBlockOrder_SwitchWithDefaultOrderNaturallyReversed) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpSwitch %selector %80 20 %20 30 %30
+
+     %80 = OpLabel ; the default case
+     OpBranch %99
+
+     %99 = OpLabel
+     OpReturn
+
+     %30 = OpLabel
+     OpReturn
+
+     %20 = OpLabel
+     OpBranch %99
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 20, 80, 99));
+}
+
+TEST_F(SpvParserTest, ComputeBlockOrder_RespectSwitchCaseFallthrough) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpSwitch %selector %99 20 %20 30 %30 40 %40 50 %50
+
+     %50 = OpLabel
+     OpBranch %99
+
+     %99 = OpLabel
+     OpReturn
+
+     %40 = OpLabel
+     OpBranch %99
+
+     %30 = OpLabel
+     OpBranch %50 ; fallthrough
+
+     %20 = OpLabel
+     OpBranch %40 ; fallthrough
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 50, 20, 40, 99)) << assembly;
+}
+
+TEST_F(SpvParserTest,
+       ComputeBlockOrder_RespectSwitchCaseFallthrough_FromDefault) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpSwitch %selector %80 20 %20 30 %30 40 %40
+
+     %80 = OpLabel ; the default case
+     OpBranch %30 ; fallthrough to another case
+
+     %99 = OpLabel
+     OpReturn
+
+     %40 = OpLabel
+     OpBranch %99
+
+     %30 = OpLabel
+     OpBranch %40
+
+     %20 = OpLabel
+     OpBranch %99
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 80, 30, 40, 99)) << assembly;
+}
+
+TEST_F(SpvParserTest,
+       ComputeBlockOrder_RespectSwitchCaseFallthrough_FromCaseToDefaultToCase) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpSwitch %selector %80 20 %20 30 %30
+
+     %99 = OpLabel
+     OpReturn
+
+     %20 = OpLabel
+     OpBranch %80 ; fallthrough to default
+
+     %80 = OpLabel ; the default case
+     OpBranch %30 ; fallthrough to 30
+
+     %30 = OpLabel
+     OpBranch %99
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 80, 30, 99)) << assembly;
+}
+
+TEST_F(SpvParserTest,
+       ComputeBlockOrder_SwitchCasesFallthrough_OppositeDirections) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpSwitch %selector %99 20 %20 30 %30 40 %40 50 %50
+
+     %99 = OpLabel
+     OpReturn
+
+     %20 = OpLabel
+     OpBranch %30 ; forward
+
+     %40 = OpLabel
+     OpBranch %99
+
+     %30 = OpLabel
+     OpBranch %99
+
+     ; SPIR-V doesn't actually allow a fall-through that goes backward in the
+     ; module. But the block ordering algorithm tolerates it.
+     %50 = OpLabel
+     OpBranch %40 ; backward
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 50, 40, 20, 30, 99)) << assembly;
+}
+
+TEST_F(SpvParserTest,
+       ComputeBlockOrder_RespectSwitchCaseFallthrough_Interleaved) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpSwitch %selector %99 20 %20 30 %30 40 %40 50 %50
+
+     %99 = OpLabel
+     OpReturn
+
+     %20 = OpLabel
+     OpBranch %40
+
+     %30 = OpLabel
+     OpBranch %50
+
+     %40 = OpLabel
+     OpBranch %60
+
+     %50 = OpLabel
+     OpBranch %70
+
+     %60 = OpLabel
+     OpBranch %99
+
+     %70 = OpLabel
+     OpBranch %99
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.ComputeBlockOrderAndPositions();
+
+  EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 50, 70, 20, 40, 60, 99))
+      << assembly;
+}
+
+// TODO(dneto): test nesting
+// TODO(dneto): test loops
+
+}  // namespace
+}  // namespace spirv
+}  // namespace reader
+}  // namespace tint