[spirv-reader] Hoist definitions as needed

Compensate for the fact that dominance does not correspond
exactly to scoping. A definition can dominate a use, but when mapped
in a naive way to constant definitiion and its use, the definition
name goes out of scope by the time you reach the use.

This is correct for storable types.

Bug: tint:3
Change-Id: I03e6c5ba68393151485ed4cdbe6b2b3d7773d1ad
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24141
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index ad68e73..8b89d80 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -624,7 +624,7 @@
 
   // TODO(dneto): register phis
   // TODO(dneto): register SSA values which need to be hoisted
-  RegisterValuesNeedingNamedDefinition();
+  RegisterValuesNeedingNamedOrHoistedDefinition();
 
   if (!EmitFunctionVariables()) {
     return false;
@@ -2361,6 +2361,19 @@
     // Only emit this part of the basic block once.
     return true;
   }
+
+  // Emit declarations of hoisted variables.
+  for (auto id : block_info.hoisted_ids) {
+    const auto* def_inst = def_use_mgr_->GetDef(id);
+    assert(def_inst);
+    AddStatement(
+        std::make_unique<ast::VariableDeclStatement>(parser_impl_.MakeVariable(
+            id, ast::StorageClass::kFunction,
+            parser_impl_.ConvertType(def_inst->type_id()))));
+    // Save this as an already-named value.
+    identifier_values_.insert(id);
+  }
+
   const spvtools::opt::BasicBlock& bb = *(block_info.basic_block);
   const auto* terminator = bb.terminator();
   const auto* merge = bb.GetMergeInst();  // Might be nullptr
@@ -2399,22 +2412,38 @@
   return success();
 }
 
+bool FunctionEmitter::EmitConstDefOrWriteToHoistedVar(
+    const spvtools::opt::Instruction& inst,
+    TypedExpression ast_expr) {
+  const auto result_id = inst.result_id();
+  if (needs_hoisted_def_.count(result_id) != 0) {
+    // Emit an assignment of the expression to the hoisted variable.
+    AddStatement(std::make_unique<ast::AssignmentStatement>(
+        std::make_unique<ast::IdentifierExpression>(namer_.Name(result_id)),
+        std::move(ast_expr.expr)));
+    return true;
+  }
+  return EmitConstDefinition(inst, std::move(ast_expr));
+}
+
 bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
-  // Handle combinatorial instructions first.
+  const auto result_id = inst.result_id();
+  // Handle combinatorial instructions.
   auto combinatorial_expr = MaybeEmitCombinatorialValue(inst);
   if (combinatorial_expr.expr != nullptr) {
-    if ((needs_named_const_def_.count(inst.result_id()) == 0) &&
+    if ((needs_hoisted_def_.count(result_id) == 0) &&
+        (needs_named_const_def_.count(result_id) == 0) &&
         (def_use_mgr_->NumUses(&inst) == 1)) {
       // If it's used once, and doesn't need a named constant definition,
       // then defer emitting the expression until it's used. Any supporting
       // statements have already been emitted.
       singly_used_values_.insert(
-          std::make_pair(inst.result_id(), std::move(combinatorial_expr)));
+          std::make_pair(result_id, std::move(combinatorial_expr)));
       return success();
     }
     // Otherwise, generate a const definition for it now and later use
     // the const's name at the uses of the value.
-    return EmitConstDefinition(inst, std::move(combinatorial_expr));
+    return EmitConstDefOrWriteToHoistedVar(inst, std::move(combinatorial_expr));
   }
   if (failed()) {
     return false;
@@ -2435,13 +2464,13 @@
     case SpvOpLoad:
       // Memory accesses must be issued in SPIR-V program order.
       // So represent a load by a new const definition.
-      return EmitConstDefinition(
+      return EmitConstDefOrWriteToHoistedVar(
           inst, MakeExpression(inst.GetSingleWordInOperand(0)));
     case SpvOpCopyObject:
       // Arguably, OpCopyObject is purely combinatorial. On the other hand,
       // it exists to make a new name for something. So we choose to make
       // a new named constant definition.
-      return EmitConstDefinition(
+      return EmitConstDefOrWriteToHoistedVar(
           inst, MakeExpression(inst.GetSingleWordInOperand(0)));
     case SpvOpFunctionCall:
       // TODO(dneto): Fill this out.  Make this pass, for existing tests
@@ -2876,7 +2905,7 @@
                            result_type, std::move(values))};
 }
 
-void FunctionEmitter::RegisterValuesNeedingNamedDefinition() {
+void FunctionEmitter::RegisterValuesNeedingNamedOrHoistedDefinition() {
   // Maps a result ID to the block position where it is last used.
   std::unordered_map<uint32_t, uint32_t> id_to_last_use_pos;
   // List of pairs of (result id, block position of the definition).
@@ -2930,12 +2959,32 @@
     auto last_use_where = id_to_last_use_pos.find(id);
     if (last_use_where != id_to_last_use_pos.end()) {
       const auto last_use_pos = last_use_where->second;
-      const auto* def_in_construct =
+      const auto* const def_in_construct =
           GetBlockInfo(block_order_[def_pos])->construct;
-      const auto* last_use_in_construct =
+      const auto* const construct_with_last_use =
           GetBlockInfo(block_order_[last_use_pos])->construct;
-      if (def_in_construct != last_use_in_construct) {
-        needs_named_const_def_.insert(id);
+
+      // Find the smallest structured construct that encloses the definition
+      // and all its uses.
+      const auto* enclosing_construct = def_in_construct;
+      while (enclosing_construct &&
+             !enclosing_construct->ContainsPos(last_use_pos)) {
+        enclosing_construct = enclosing_construct->parent;
+      }
+      // At worst, we go all the way out to the function construct.
+      assert(enclosing_construct != nullptr);
+
+      if (def_in_construct != construct_with_last_use) {
+        if (enclosing_construct == def_in_construct) {
+          // We can use a plain 'const' definition.
+          needs_named_const_def_.insert(id);
+        } else {
+          // We need to make a hoisted variable definition.
+          // TODO(dneto): Handle non-storable types, particularly pointers.
+          needs_hoisted_def_.insert(id);
+          auto* hoist_to_block = GetBlockInfo(enclosing_construct->begin_id);
+          hoist_to_block->hoisted_ids.push_back(id);
+        }
       }
     }
   }
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 9a8ebad..484df30 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -157,6 +157,11 @@
   /// This occurs when a block in this selection has both an if-break edge, and
   /// also a different normal forward edge but without a merge instruction.
   std::string flow_guard_name = "";
+
+  /// The result IDs that this block is responsible for declaring as a
+  /// hoisted variable.  See the |needs_hoisted_def_| member of
+  /// FunctionEmitter for an explanation.
+  std::vector<uint32_t> hoisted_ids;
 };
 
 inline std::ostream& operator<<(std::ostream& o, const BlockInfo& bi) {
@@ -278,11 +283,18 @@
   bool FindIfSelectionInternalHeaders();
 
   /// Record the SPIR-V IDs of non-constants that should get a 'const'
-  /// definition in WGSL. This occurs when a SPIR-V instruction might use the
-  /// dynamically computed value only once, but the WGSL code might reference
-  /// it multiple times. For example, this occurs for the vector operands of
-  /// OpVectorShuffle.  Populates |needs_named_const_def_|
-  void RegisterValuesNeedingNamedDefinition();
+  /// definition in WGSL, or a 'var' definition at an outer scope.
+  /// This occurs in several cases:
+  ///  - When a SPIR-V instruction might use the dynamically computed value
+  ///    only once, but the WGSL code might reference it multiple times.
+  ///    For example, this occurs for the vector operands of OpVectorShuffle.
+  ///    In this case the definition is added to |needs_named_const_def_|.
+  ///  - When a definition and at least one of its uses are not in the
+  ///    same structured construct.
+  ///    In this case the definition is added to |needs_named_const_def_|.
+  ///  - When a definition is in a construct that does not enclose all the
+  ///    uses.  In this case the definition is added to |needs_hoisted_def_|.
+  void RegisterValuesNeedingNamedOrHoistedDefinition();
 
   /// Emits declarations of function variables.
   /// @returns false if emission failed.
@@ -431,6 +443,15 @@
   bool EmitConstDefinition(const spvtools::opt::Instruction& inst,
                            TypedExpression ast_expr);
 
+  /// Emits a write to a hoisted variable for the given SPIR-V id,
+  /// if that ID has a hoisted declaration. Otherwise, emits a const
+  /// definition instead.
+  /// @param inst the SPIR-V instruction defining the value
+  /// @param ast_expr the already-computed AST expression for the value
+  /// @returns false if emission failed.
+  bool EmitConstDefOrWriteToHoistedVar(const spvtools::opt::Instruction& inst,
+                                       TypedExpression ast_expr);
+
   /// Makes an expression
   /// @param id the SPIR-V ID of the value
   /// @returns true if emission has not yet failed.
@@ -603,6 +624,19 @@
   std::unordered_map<uint32_t, TypedExpression> singly_used_values_;
   // Set of SPIR-V IDs which should get a named const definition.
   std::unordered_set<uint32_t> needs_named_const_def_;
+  // The SPIR-V IDs that must be declared in WGSL before the corresponding
+  // location in SPIR-V. This compensates for the difference between dominance
+  // and scoping. An SSA definition can dominate all its uses, but the construct
+  // where it is defined does not enclose all the uses, and so if it were
+  // declared as a WGSL constant definition at the point of its SPIR-V
+  // definition, then the WGSL name would go out of scope too early. Fix that by
+  // creating a variable at the top of the smallest construct that encloses both
+  // the definition and all its uses. Then the original SPIR-V definition maps
+  // to a WGSL assignment to that variable, and each SPIR-V use becomes a WGSL
+  // read from the variable.
+  // TODO(dneto): This works for constants of storable type, but not, for
+  // example, pointers.
+  std::unordered_set<uint32_t> needs_hoisted_def_;
 
   // The IDs of basic blocks, in reverse structured post-order (RSPO).
   // This is the output order for the basic blocks.
diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc
index ca8e217..85ac938 100644
--- a/src/reader/spirv/function_var_test.cc
+++ b/src/reader/spirv/function_var_test.cc
@@ -61,9 +61,13 @@
     %false = OpConstantFalse %bool
     %float_0 = OpConstant %float 0.0
     %float_1p5 = OpConstant %float 1.5
+    %uint_0 = OpConstant %uint 0
     %uint_1 = OpConstant %uint 1
     %int_m1 = OpConstant %int -1
     %uint_2 = OpConstant %uint 2
+    %uint_3 = OpConstant %uint 3
+    %uint_4 = OpConstant %uint 4
+    %uint_5 = OpConstant %uint 5
 
     %v2float = OpTypeVector %float 2
     %m3v2float = OpTypeMatrix %v2float 3
@@ -794,6 +798,137 @@
 )")) << ToString(fe.ast_body());
 }
 
+TEST_F(
+    SpvParserTest,
+    EmitStatement_CombinatorialNonPointer_DefConstruct_DoesNotEncloseAllUses) {
+  // Compensate for the difference between dominance and scoping.
+  // Exercise hoisting of the constant definition to before its natural
+  // location.
+  //
+  // The definition of %2 should be hoisted
+  auto assembly = Preamble() + R"(
+     %pty = OpTypePointer Private %uint
+     %1 = OpVariable %pty Private
+
+     %100 = OpFunction %void None %voidfn
+
+     %3 = OpLabel
+     OpStore %1 %uint_0
+     OpBranch %5
+
+     %5 = OpLabel
+     OpStore %1 %uint_1
+     OpLoopMerge  %99 %80 None
+     OpBranchConditional %false %99 %20
+
+     %20 = OpLabel
+     OpStore %1 %uint_3
+     OpSelectionMerge %50 None
+     OpBranchConditional %true %30 %40
+
+     %30 = OpLabel
+     ; This combinatorial definition in nested control flow dominates
+     ; the use in the merge block in %50
+     %2 = OpIAdd %uint %uint_1 %uint_1
+     OpBranch %50
+
+     %40 = OpLabel
+     OpReturn
+
+     %50 = OpLabel ; merge block for if-selection
+     OpStore %1 %2
+     OpBranch %80
+
+     %80 = OpLabel ; merge block
+     OpStore %1 %uint_4
+     OpBranchConditional %false %99 %5 ; loop backedge
+
+     %99 = OpLabel
+     OpStore %1 %uint_5
+     OpReturn
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{
+  Identifier{x_1}
+  ScalarConstructor{0}
+}
+Loop{
+  VariableDeclStatement{
+    Variable{
+      x_2
+      function
+      __u32
+    }
+  }
+  Assignment{
+    Identifier{x_1}
+    ScalarConstructor{1}
+  }
+  If{
+    (
+      ScalarConstructor{false}
+    )
+    {
+      Break{}
+    }
+  }
+  Assignment{
+    Identifier{x_1}
+    ScalarConstructor{3}
+  }
+  If{
+    (
+      ScalarConstructor{true}
+    )
+    {
+      Assignment{
+        Identifier{x_2}
+        Binary{
+          ScalarConstructor{1}
+          add
+          ScalarConstructor{1}
+        }
+      }
+    }
+  }
+  Else{
+    {
+      Return{}
+    }
+  }
+  Assignment{
+    Identifier{x_1}
+    Identifier{x_2}
+  }
+  continuing {
+    Assignment{
+      Identifier{x_1}
+      ScalarConstructor{4}
+    }
+    If{
+      (
+        ScalarConstructor{false}
+      )
+      {
+        Break{}
+      }
+    }
+  }
+}
+Assignment{
+  Identifier{x_1}
+  ScalarConstructor{5}
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader