[spirv-reader] Improve placement of hoisted vars

When we hoist a variable out of a continue construct, put it in the
associated loop construct, if one exists.  This reduces its lifetime
in WGSL, and makes the result easier for a code reader to understand.

Change-Id: I8f0cc37640bfe67874cbc27b55029e79e9a8992c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24321
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/construct.h b/src/reader/spirv/construct.h
index 767d5de..5e6f754 100644
--- a/src/reader/spirv/construct.h
+++ b/src/reader/spirv/construct.h
@@ -218,6 +218,13 @@
   return ss.str();
 }
 
+/// Converts a construct pointer, which may be null, to a string.
+/// @param c the construct, or nullptr
+/// @returns ToString(*c) if `c` is non-null, otherwise ToStringBrief(c)
+inline std::string ToString(const Construct* c) {
+  return c ? ToString(*c) : ToStringBrief(c);
+}
+
 /// Converts a unique pointer to a construct to a string.
 /// @param c the construct
 /// @returns the string representation
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index dad106c..df4dc52 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -1145,6 +1145,30 @@
   return nullptr;
 }
 
+const Construct* FunctionEmitter::SiblingLoopConstruct(
+    const Construct* c) const {
+  if (c == nullptr || c->kind != Construct::kContinue) {
+    return nullptr;
+  }
+  const uint32_t continue_target_id = c->begin_id;
+  const auto* continue_target = GetBlockInfo(continue_target_id);
+  const uint32_t header_id = continue_target->header_for_continue;
+  if (continue_target_id == header_id) {
+    // The continue target is also the loop header: there is no sibling loop.
+    return nullptr;
+  }
+  const auto* candidate = GetBlockInfo(header_id)->construct;
+  // Start at the construct recorded for the header block and walk up the
+  // construct tree until we hit the loop.  In the future we might handle
+  // the corner case where the same block is both a loop header and a
+  // selection header, e.g. where the loop header block has a conditional
+  // branch going to distinct targets inside the loop body.
+  while (candidate && candidate->kind != Construct::kLoop) {
+    candidate = candidate->parent;
+  }
+  return candidate;
+}
+
 bool FunctionEmitter::ClassifyCFGEdges() {
   if (failed()) {
     return false;
@@ -3013,13 +3037,17 @@
     const auto block_pos = block_info->pos;
     for (const auto& inst : *(block_info->basic_block)) {
       // Update the usage span for IDs used by this instruction.
-      inst.ForEachInId([this, block_pos](const uint32_t* id_ptr) {
-        auto* def_info = GetDefInfo(*id_ptr);
-        if (def_info) {
-          def_info->num_uses++;
-          def_info->last_use_pos = std::max(def_info->last_use_pos, block_pos);
-        }
-      });
+      // But skip uses in OpPhi because they are handled differently.
+      if (inst.opcode() != SpvOpPhi) {
+        inst.ForEachInId([this, block_pos](const uint32_t* id_ptr) {
+          auto* def_info = GetDefInfo(*id_ptr);
+          if (def_info) {
+            def_info->num_uses++;
+            def_info->last_use_pos =
+                std::max(def_info->last_use_pos, block_pos);
+          }
+        });
+      }
 
       if (inst.opcode() == SpvOpPhi) {
         // Declare a name for the variable used to carry values to a phi.
@@ -3051,7 +3079,7 @@
 
         // Schedule the declaration of the state variable.
         const auto* enclosing_construct =
-            GetSmallestEnclosingConstruct(first_pos, last_pos);
+            GetEnclosingScope(first_pos, last_pos);
         GetBlockInfo(enclosing_construct->begin_id)
             ->phis_needing_state_vars.push_back(phi_id);
       }
@@ -3088,7 +3116,7 @@
 
     if (def_in_construct != construct_with_last_use) {
       const auto* enclosing_construct =
-          GetSmallestEnclosingConstruct(first_pos, last_use_pos);
+          GetEnclosingScope(first_pos, last_use_pos);
       if (enclosing_construct == def_in_construct) {
         // We can use a plain 'const' definition.
         def_info->requires_named_const_def = true;
@@ -3103,15 +3131,19 @@
   }
 }
 
-const Construct* FunctionEmitter::GetSmallestEnclosingConstruct(
-    uint32_t first_pos,
-    uint32_t last_pos) const {
+const Construct* FunctionEmitter::GetEnclosingScope(uint32_t first_pos,
+                                                    uint32_t last_pos) const {
   const auto* enclosing_construct =
       GetBlockInfo(block_order_[first_pos])->construct;
   assert(enclosing_construct != nullptr);
   // Constructs are strictly nesting, so follow parent pointers
   while (enclosing_construct && !enclosing_construct->ContainsPos(last_pos)) {
-    enclosing_construct = enclosing_construct->parent;
+    // The scope of a continue construct is enclosed in its associated loop
+    // construct, but they are siblings in our construct tree.
+    const auto* sibling_loop = SiblingLoopConstruct(enclosing_construct);
+    // Go to the sibling loop if it exists, otherwise walk up to the parent.
+    enclosing_construct =
+        sibling_loop ? sibling_loop : enclosing_construct->parent;
   }
   // At worst, we go all the way out to the function construct.
   assert(enclosing_construct != nullptr);
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 7044b0f..606bcde 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -601,14 +601,23 @@
     return where->second.get();
   }
 
-  /// Returns the most deeply nested structured construct which encloses
-  /// both block positions. Each position must be a valid index into the
-  /// function block order array.
+  /// Returns the most deeply nested structured construct which encloses the
+  /// WGSL scopes of names declared in both block positions. Each position must
+  /// be a valid index into the function block order array.
   /// @param first_pos the first block position
   /// @param last_pos the last block position
   /// @returns the smallest construct containing both positions
-  const Construct* GetSmallestEnclosingConstruct(uint32_t first_pos,
-                                                 uint32_t last_pos) const;
+  const Construct* GetEnclosingScope(uint32_t first_pos,
+                                     uint32_t last_pos) const;
+
+  /// Finds the loop construct associated with a continue construct, if any.
+  /// Returns nullptr if:
+  ///  - the given construct is not a continue construct
+  ///  - the continue construct does not have an associated loop construct
+  ///    (the continue target is also the loop header block)
+  /// @param c the continue construct
+  /// @returns the associated loop construct, or nullptr
+  const Construct* SiblingLoopConstruct(const Construct* c) const;
 
  private:
   /// @returns the store type for the OpVariable instruction, or
diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc
index d245427..188b3b9 100644
--- a/src/reader/spirv/function_cfg_test.cc
+++ b/src/reader/spirv/function_cfg_test.cc
@@ -17,6 +17,7 @@
 #include <vector>
 
 #include "gmock/gmock.h"
+#include "src/reader/spirv/construct.h"
 #include "src/reader/spirv/function.h"
 #include "src/reader/spirv/parser_impl.h"
 #include "src/reader/spirv/parser_impl_test_helper.h"
@@ -87,15 +88,22 @@
   )";
 }
 
-/// Runs the necessary flow until and including finding switch case
-/// headers.
-/// @returns the result of finding switch case headers.
-bool FlowFindSwitchCaseHeaders(FunctionEmitter* fe) {
+/// Runs the necessary flow until and including labeling control
+/// flow constructs.
+/// @returns the result of labeling control flow constructs.
+bool FlowLabelControlFlowConstructs(FunctionEmitter* fe) {
   fe->RegisterBasicBlocks();
   EXPECT_TRUE(fe->RegisterMerges()) << fe->parser()->error();
   fe->ComputeBlockOrderAndPositions();
   EXPECT_TRUE(fe->VerifyHeaderContinueMergeOrder()) << fe->parser()->error();
-  EXPECT_TRUE(fe->LabelControlFlowConstructs()) << fe->parser()->error();
+  return fe->LabelControlFlowConstructs();
+}
+
+/// Runs the necessary flow until and including finding switch case
+/// headers.
+/// @returns the result of finding switch case headers.
+bool FlowFindSwitchCaseHeaders(FunctionEmitter* fe) {
+  EXPECT_TRUE(FlowLabelControlFlowConstructs(fe)) << fe->parser()->error();
   return fe->FindSwitchCaseHeaders();
 }
 
@@ -13418,6 +13426,121 @@
       << ToString(fe.ast_body());
 }
 
+TEST_F(SpvParserTest, SiblingLoopConstruct_Null) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %10 = OpLabel
+     OpReturn
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_EQ(fe.SiblingLoopConstruct(nullptr), nullptr);
+}
+
+TEST_F(SpvParserTest, SiblingLoopConstruct_NotAContinue) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  ASSERT_TRUE(FlowLabelControlFlowConstructs(&fe)) << p->error();
+  const Construct* c = fe.GetBlockInfo(10)->construct;
+  EXPECT_NE(c, nullptr);
+  EXPECT_EQ(fe.SiblingLoopConstruct(c), nullptr);
+}
+
+TEST_F(SpvParserTest, SiblingLoopConstruct_SingleBlockLoop) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %20 = OpLabel
+     OpLoopMerge %99 %20 None
+     OpBranchConditional %cond %20 %99
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  ASSERT_TRUE(FlowLabelControlFlowConstructs(&fe)) << p->error();
+  const Construct* c = fe.GetBlockInfo(20)->construct;
+  EXPECT_EQ(c->kind, Construct::kContinue);
+  EXPECT_EQ(fe.SiblingLoopConstruct(c), nullptr);
+}
+
+TEST_F(SpvParserTest, SiblingLoopConstruct_ContinueIsWholeMultiBlockLoop) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %20 = OpLabel
+     OpLoopMerge %99 %20 None ; continue target is also loop header
+     OpBranchConditional %cond %30 %99
+
+     %30 = OpLabel
+     OpBranch %20
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+      << p->error() << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  ASSERT_TRUE(FlowLabelControlFlowConstructs(&fe)) << p->error();
+  const Construct* c = fe.GetBlockInfo(20)->construct;
+  EXPECT_EQ(c->kind, Construct::kContinue);
+  EXPECT_EQ(fe.SiblingLoopConstruct(c), nullptr);
+}
+
+TEST_F(SpvParserTest, SiblingLoopConstruct_HasSiblingLoop) {
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %20 = OpLabel
+     OpLoopMerge %99 %30 None
+     OpBranchConditional %cond %30 %99
+
+     %30 = OpLabel ; continue target
+     OpBranch %20
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  ASSERT_TRUE(FlowLabelControlFlowConstructs(&fe)) << p->error();
+  const Construct* c = fe.GetBlockInfo(30)->construct;
+  EXPECT_EQ(c->kind, Construct::kContinue);
+  EXPECT_THAT(ToString(fe.SiblingLoopConstruct(c)),
+              Eq("Construct{ Loop [1,2) begin_id:20 end_id:30 depth:1 "
+                 "parent:Function@10 in-l:Loop@20 }"));
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader
diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc
index 6693ae7..c591175 100644
--- a/src/reader/spirv/function_var_test.cc
+++ b/src/reader/spirv/function_var_test.cc
@@ -1042,13 +1042,23 @@
         }
       }
     }
+    VariableDeclStatement{
+      Variable{
+        x_4
+        none
+        __u32
+        {
+          Binary{
+            Identifier{x_2}
+            add
+            ScalarConstructor{1}
+          }
+        }
+      }
+    }
     Assignment{
       Identifier{x_2_phi}
-      Binary{
-        Identifier{x_2}
-        add
-        ScalarConstructor{1}
-      }
+      Identifier{x_4}
     }
     Assignment{
       Identifier{x_3_phi}
@@ -1122,13 +1132,6 @@
   }
   VariableDeclStatement{
     Variable{
-      x_4
-      function
-      __u32
-    }
-  }
-  VariableDeclStatement{
-    Variable{
       x_2_phi
       function
       __u32
@@ -1201,12 +1204,18 @@
       }
     }
     continuing {
-      Assignment{
-        Identifier{x_4}
-        Binary{
-          Identifier{x_2}
-          add
-          ScalarConstructor{1}
+      VariableDeclStatement{
+        Variable{
+          x_4
+          none
+          __u32
+          {
+            Binary{
+              Identifier{x_2}
+              add
+              ScalarConstructor{1}
+            }
+          }
         }
       }
       Assignment{
@@ -1224,6 +1233,174 @@
 )")) << ToString(fe.ast_body());
 }
 
+TEST_F(SpvParserTest, EmitStatement_Phi_ValueFromLoopBodyAndContinuing) {
+  auto assembly = Preamble() + R"(
+     %pty = OpTypePointer Private %uint
+     %1 = OpVariable %pty Private
+     %boolpty = OpTypePointer Private %bool
+     %17 = OpVariable %boolpty Private
+
+     %100 = OpFunction %void None %voidfn
+
+     %9 = OpLabel
+     %101 = OpLoad %bool %17
+     OpBranch %10
+
+     ; Use an outer loop to show we put the new variable in the
+     ; smallest enclosing scope.
+     %10 = OpLabel
+     OpLoopMerge %99 %89 None
+     OpBranch %20
+
+     %20 = OpLabel
+     %2 = OpPhi %uint %uint_0 %10 %4 %30  ; gets computed value
+     %5 = OpPhi %uint %uint_1 %10 %7 %30
+     %4 = OpIAdd %uint %2 %uint_1 ; define %4
+     %6 = OpIAdd %uint %4 %uint_1 ; use %4
+     OpLoopMerge %89 %30 None
+     OpBranchConditional %101 %89 %30
+
+     %30 = OpLabel
+     %7 = OpIAdd %uint %4 %6 ; use %4 again
+     OpBranch %20
+
+     %89 = OpLabel
+     OpBranch %10
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+      << assembly << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{
+  Variable{
+    x_101
+    none
+    __bool
+    {
+      Identifier{x_17}
+    }
+  }
+}
+Loop{
+  VariableDeclStatement{
+    Variable{
+      x_4
+      function
+      __u32
+    }
+  }
+  VariableDeclStatement{
+    Variable{
+      x_6
+      function
+      __u32
+    }
+  }
+  VariableDeclStatement{
+    Variable{
+      x_2_phi
+      function
+      __u32
+    }
+  }
+  VariableDeclStatement{
+    Variable{
+      x_5_phi
+      function
+      __u32
+    }
+  }
+  Assignment{
+    Identifier{x_2_phi}
+    ScalarConstructor{0}
+  }
+  Assignment{
+    Identifier{x_5_phi}
+    ScalarConstructor{1}
+  }
+  Loop{
+    VariableDeclStatement{
+      Variable{
+        x_2
+        none
+        __u32
+        {
+          Identifier{x_2_phi}
+        }
+      }
+    }
+    VariableDeclStatement{
+      Variable{
+        x_5
+        none
+        __u32
+        {
+          Identifier{x_5_phi}
+        }
+      }
+    }
+    Assignment{
+      Identifier{x_4}
+      Binary{
+        Identifier{x_2}
+        add
+        ScalarConstructor{1}
+      }
+    }
+    Assignment{
+      Identifier{x_6}
+      Binary{
+        Identifier{x_4}
+        add
+        ScalarConstructor{1}
+      }
+    }
+    If{
+      (
+        Identifier{x_101}
+      )
+      {
+        Break{}
+      }
+    }
+    continuing {
+      VariableDeclStatement{
+        Variable{
+          x_7
+          none
+          __u32
+          {
+            Binary{
+              Identifier{x_4}
+              add
+              Identifier{x_6}
+            }
+          }
+        }
+      }
+      Assignment{
+        Identifier{x_2_phi}
+        Identifier{x_4}
+      }
+      Assignment{
+        Identifier{x_5_phi}
+        Identifier{x_7}
+      }
+    }
+  }
+}
+Return{}
+)")) << ToString(fe.ast_body())
+     << assembly;
+}
+
 TEST_F(SpvParserTest, EmitStatement_Phi_FromElseAndThen) {
   auto assembly = Preamble() + R"(
      %pty = OpTypePointer Private %uint