[spirv-reader] Verify order among header, continue, merge

This is gives us the fundamental ordering of blocks in relation
to a structured construct.

Bug: tint:3
Change-Id: I76eb39403131305398808c33ce4cee256a1c23c2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20266
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 426a4b1..27f4b20 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -37,6 +37,105 @@
 #include "src/reader/spirv/fail_stream.h"
 #include "src/reader/spirv/parser_impl.h"
 
+// Terms:
+//    CFG: the control flow graph of the function, where basic blocks are the
+//    nodes, and branches form the directed arcs.  The function entry block is
+//    the root of the CFG.
+//
+//    Suppose H is a header block (i.e. has an OpSelectionMerge or OpLoopMerge).
+//    Then:
+//    - Let M(H) be the merge block named by the merge instruction in H.
+//    - If H is a loop header, i.e. has an OpLoopMerge instruction, then let
+//      CT(H) be the continue target block named by the OpLoopMerge
+//      instruction.
+//    - If H is a selection construct whose header ends in
+//      OpBranchConditional with true target %then and false target %else,
+//      then  TT(H) = %then and FT(H) = %else
+//
+// Determining output block order:
+//    The "structured post-order traversal" of the CFG is a post-order traversal
+//    of the basic blocks in the CFG, where:
+//      We visit the entry node of the function first.
+//      When visiting a header block:
+//        We next visit its merge block
+//        Then if it's a loop header, we next visit the continue target,
+//      Then we visit the block's successors (whether it's a header or not)
+//        If the block ends in an OpBranchConditional, we visit the false target
+//        before the true target.
+//
+//    The "reverse structured post-order traversal" of the CFG is the reverse
+//    of the structured post-order traversal.
+//    This is the order of basic blocks as they should be emitted to the WGSL
+//    function. It is the order computed by ComputeBlockOrder, and stored in
+//    the |FunctionEmiter::block_order_|.
+//    Blocks not in this ordering are ignored by the rest of the algorithm.
+//
+//    Note:
+//     - A block D in the function might not appear in this order because
+//       no block in the order branches to D.
+//     - An unreachable block D might still be in the order because some header
+//       block in the order names D as its continue target, or merge block,
+//       or D is reachable from one of those otherwise-unreachable continue
+//       targets or merge blocks.
+//
+// Terms:
+//    Let Pos(B) be the index position of a block B in the computed block order.
+//
+// CFG intervals and valid nesting:
+//
+//    A correctly structured CFG satisfies nesting rules that we can check by
+//    comparing positions of related blocks.
+//
+//    If header block H is in the block order, then the following holds:
+//
+//      Pos(H) < Pos(M(H))
+//
+//      If CT(H) exists, then:
+//
+//         Pos(H) <= Pos(CT(H)), with equality exactly for single-block loops
+//         Pos(CT(H)) < Pos(M)
+//
+//    This gives us the fundamental ordering of blocks in relation to a
+//    structured construct:
+//      The blocks before H in the block order, are not in the construct
+//      The blocks at M(H) or later in the block order, are not in the construct
+//      The blocks in a selection headed at H are in positions [ Pos(H),
+//      Pos(M(H)) ) The blocks in a loop construct headed at H are in positions
+//      [ Pos(H), Pos(CT(H)) ) The blocks in the continue construct for loop
+//      headed at H are in
+//        positions [ Pos(CT(H)), Pos(M(H)) )
+//
+//      Schematically, for a selection construct headed by H, the blocks are in
+//      order from left to right:
+//
+//                 ...a-b-c H d-e-f M(H) n-o-p...
+//
+//           where ...a-b-c: blocks before the selection construct
+//           where H and d-e-f: blocks in the selection construct
+//           where M(H) and n-o-p...: blocks after the selection construct
+//
+//      Schematically, for a single-block loop construct headed by H, there are
+//      blocks in order from left to right:
+//
+//                 ...a-b-c H M(H) n-o-p...
+//
+//           where ...a-b-c: blocks before the loop
+//           where H is the continue construct; CT(H)=H, and the loop construct
+//           is *empty* where M(H) and n-o-p...: blocks after the loop and
+//           continue constructs
+//
+//      Schematically, for a multi-block loop construct headed by H, there are
+//      blocks in order from left to right:
+//
+//                 ...a-b-c H d-e-f CT(H) j-k-l M(H) n-o-p...
+//
+//           where ...a-b-c: blocks before the loop
+//           where H and d-e-f: blocks in the loop construct
+//           where CT(H) and j-k-l: blocks in the continue construct
+//           where M(H) and n-o-p...: blocks after the loop and continue
+//           constructs
+//
+
 namespace tint {
 namespace reader {
 namespace spirv {
@@ -335,6 +434,9 @@
   }
 
   ComputeBlockOrderAndPositions();
+  if (!VerifyHeaderContinueMergeOrder()) {
+    return false;
+  }
 
   if (!EmitFunctionVariables()) {
     return false;
@@ -493,6 +595,71 @@
   }
 }
 
+bool FunctionEmitter::VerifyHeaderContinueMergeOrder() {
+  // Verify interval rules for a structured header block:
+  //
+  //    If the CFG satisfies structured control flow rules, then:
+  //    If header H is reachable, then the following "interval rules" hold,
+  //    where M(H) is H's merge block, and CT(H) is H's continue target:
+  //
+  //      Pos(H) < Pos(M(H))
+  //
+  //      If CT(H) exists, then:
+  //         Pos(H) <= Pos(CT(H)), with equality exactly for single-block loops
+  //         Pos(CT(H)) < Pos(M)
+  //
+  for (auto block_id : block_order_) {
+    const auto* block_info = GetBlockInfo(block_id);
+    const auto merge = block_info->merge_for_header;
+    if (merge == 0) {
+      continue;
+    }
+      // This is a header.
+      const auto header = block_id;
+      const auto* header_info = block_info;
+      const auto header_pos = header_info->pos;
+      const auto merge_pos = GetBlockInfo(merge)->pos;
+
+      // Pos(H) < Pos(M(H))
+      // Note: When recording merges we made sure H != M(H)
+      if (merge_pos <= header_pos) {
+        return Fail() << "Header " << header
+                      << " does not strictly dominate its merge block "
+                      << merge;
+        // TODO(dneto): Report a path from the entry block to the merge block
+        // without going through the header block.
+      }
+
+      const auto ct = block_info->continue_for_header;
+      if (ct == 0) {
+        continue;
+      }
+      // Furthermore, this is a loop header.
+      const auto* ct_info = GetBlockInfo(ct);
+      const auto ct_pos = ct_info->pos;
+      // Pos(H) <= Pos(CT(H)), with equality only for single-block loops.
+      if (header_info->is_single_block_loop && ct_pos != header_pos) {
+        Fail() << "Internal error: Single block loop.  CT pos is not the "
+                  "header pos. Should have already checked this";
+      }
+      if (!header_info->is_single_block_loop && (ct_pos <= header_pos)) {
+        Fail() << "Loop header " << header
+               << " does not dominate its continue target " << ct;
+      }
+        // Pos(CT(H)) < Pos(M(H))
+        // Note: When recording merges we made sure CT(H) != M(H)
+        if (merge_pos <= ct_pos) {
+          return Fail() << "Merge block " << merge
+                        << " for loop headed at block " << header
+                        << " appears at or before the loop's continue "
+                           "construct headed by "
+                           "block "
+                        << ct;
+        }
+  }
+  return success();
+}
+
 bool FunctionEmitter::EmitFunctionVariables() {
   if (failed()) {
     return false;
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 46f15d7..349a30b 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -16,6 +16,7 @@
 #define SRC_READER_SPIRV_FUNCTION_H_
 
 #include <memory>
+#include <ostream>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
@@ -67,6 +68,17 @@
   bool is_single_block_loop = false;
 };
 
+inline std::ostream& operator<<(std::ostream& o, const BlockInfo& bi) {
+  o << "BlockInfo{"
+    << " id: " << bi.id << " pos: " << bi.pos
+    << " merge_for_header: " << bi.merge_for_header
+    << " continue_for_header: " << bi.continue_for_header
+    << " header_for_merge: " << bi.header_for_merge
+    << " header_for_merge: " << bi.header_for_merge
+    << " single_block_loop: " << int(bi.is_single_block_loop) << "}";
+  return o;
+}
+
 /// A FunctionEmitter emits a SPIR-V function onto a Tint AST module.
 class FunctionEmitter {
  public:
@@ -129,6 +141,12 @@
   /// the function.
   const std::vector<uint32_t>& block_order() const { return block_order_; }
 
+  /// Verifies that the orderings among a structured header, continue target,
+  /// and merge block are valid. Assumes block order has been computed, and
+  /// merges are valid and recorded.
+  /// @returns false if invalid nesting was detected
+  bool VerifyHeaderContinueMergeOrder();
+
   /// Emits declarations of function variables.
   /// @returns false if emission failed.
   bool EmitFunctionVariables();
diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc
index d3f1555..9a42247 100644
--- a/src/reader/spirv/function_cfg_test.cc
+++ b/src/reader/spirv/function_cfg_test.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <sstream>
 #include <string>
 #include <vector>
 
@@ -26,6 +27,16 @@
 namespace spirv {
 namespace {
 
+std::string Dump(const std::vector<uint32_t>& v) {
+  std::ostringstream o;
+  o << "{";
+  for (auto a : v) {
+    o << a << " ";
+  }
+  o << "}";
+  return o.str();
+}
+
 using ::testing::ElementsAre;
 using ::testing::Eq;
 
@@ -2548,6 +2559,198 @@
               ElementsAre(10, 20, 30, 35, 37, 40, 49, 50, 99));
 }
 
+TEST_F(SpvParserTest, VerifyHeaderContinueMergeOrder_Selection_Good) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpBranchConditional %cond %20 %30
+
+     %20 = OpLabel
+     OpBranch %99
+
+     %30 = OpLabel
+     OpBranch %99
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.RegisterBasicBlocks();
+  fe.ComputeBlockOrderAndPositions();
+  fe.RegisterMerges();
+  EXPECT_TRUE(fe.VerifyHeaderContinueMergeOrder());
+}
+
+TEST_F(SpvParserTest, VerifyHeaderContinueMergeOrder_SingleBlockLoop_Good) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %20 = OpLabel
+     OpLoopMerge %99 %20 None
+     OpBranchConditional %cond %20 %99
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.RegisterBasicBlocks();
+  fe.ComputeBlockOrderAndPositions();
+  fe.RegisterMerges();
+  EXPECT_TRUE(fe.VerifyHeaderContinueMergeOrder()) << p->error();
+}
+
+TEST_F(SpvParserTest, VerifyHeaderContinueMergeOrder_MultiBlockLoop_Good) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %20 = OpLabel
+     OpLoopMerge %99 %30 None
+     OpBranchConditional %cond %30 %99
+
+     %30 = OpLabel
+     OpBranch %20
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.RegisterBasicBlocks();
+  fe.ComputeBlockOrderAndPositions();
+  fe.RegisterMerges();
+  EXPECT_TRUE(fe.VerifyHeaderContinueMergeOrder());
+}
+
+TEST_F(SpvParserTest,
+       VerifyHeaderContinueMergeOrder_HeaderDoesNotStrictlyDominateMerge) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %20 = OpLabel
+     OpBranch %50
+
+     %50 = OpLabel
+     OpSelectionMerge %20 None ; this is backward
+     OpBranchConditional %cond2 %60 %99
+
+     %60 = OpLabel
+     OpBranch %99
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.RegisterBasicBlocks();
+  fe.ComputeBlockOrderAndPositions();
+  fe.RegisterMerges();
+  EXPECT_FALSE(fe.VerifyHeaderContinueMergeOrder());
+  EXPECT_THAT(p->error(),
+              Eq("Header 50 does not strictly dominate its merge block 20"))
+      << *fe.GetBlockInfo(50) << std::endl
+      << *fe.GetBlockInfo(20) << std::endl
+      << Dump(fe.block_order());
+}
+
+TEST_F(
+    SpvParserTest,
+    VerifyHeaderContinueMergeOrder_HeaderDoesNotStrictlyDominateContinueTarget) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %20 = OpLabel
+     OpBranch %50
+
+     %50 = OpLabel
+     OpLoopMerge %99 %20 None ; this is backward
+     OpBranchConditional %cond %60 %99
+
+     %60 = OpLabel
+     OpBranch %50
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.RegisterBasicBlocks();
+  fe.ComputeBlockOrderAndPositions();
+  fe.RegisterMerges();
+  EXPECT_FALSE(fe.VerifyHeaderContinueMergeOrder());
+  EXPECT_THAT(p->error(),
+              Eq("Loop header 50 does not dominate its continue target 20"))
+      << *fe.GetBlockInfo(50) << std::endl
+      << *fe.GetBlockInfo(20) << std::endl
+      << Dump(fe.block_order());
+}
+
+TEST_F(SpvParserTest,
+       VerifyHeaderContinueMergeOrder_MergeInsideContinueTarget) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %50
+
+     %50 = OpLabel
+     OpLoopMerge %60 %70 None
+     OpBranchConditional %cond %60 %99
+
+     %60 = OpLabel
+     OpBranch %70
+
+     %70 = OpLabel
+     OpBranch %50
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  fe.RegisterBasicBlocks();
+  fe.ComputeBlockOrderAndPositions();
+  fe.RegisterMerges();
+  EXPECT_FALSE(fe.VerifyHeaderContinueMergeOrder());
+  EXPECT_THAT(p->error(),
+              Eq("Merge block 60 for loop headed at block 50 appears at or "
+                 "before the loop's continue construct headed by block 70"))
+      << Dump(fe.block_order());
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader