[spirv-reader] Don't move combinatorial values across control flow

Avoid sinking expensive operations into control flow such as loops.
The heuristic way to achieve that is to avoid moving combinatorial
values across *any* structured construct boundaries.

Bug: tint:3
Change-Id: I91502b01166a0db64c0e652331591850df75f9d4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24140
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 840836b..ad68e73 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2877,8 +2877,30 @@
 }
 
 void FunctionEmitter::RegisterValuesNeedingNamedDefinition() {
-  for (auto& block : function_) {
-    for (const auto& inst : block) {
+  // Maps a result ID to the block position where it is last used.
+  std::unordered_map<uint32_t, uint32_t> id_to_last_use_pos;
+  // List of pairs of (result id, block position of the definition).
+  std::vector<std::pair<uint32_t, uint32_t>> id_def_pos;
+
+  for (auto block_id : block_order_) {
+    const auto* block_info = GetBlockInfo(block_id);
+    const auto block_pos = block_info->pos;
+
+    for (const auto& inst : *(block_info->basic_block)) {
+      const auto result_id = inst.result_id();
+      if (result_id != 0) {
+        id_def_pos.emplace_back(
+            std::pair<uint32_t, uint32_t>{result_id, block_pos});
+      }
+      inst.ForEachInId(
+          [&id_to_last_use_pos, block_pos](const uint32_t* id_ptr) {
+            // If the id is not in the map already, this will create
+            // an entry with value 0.
+            auto& pos = id_to_last_use_pos[*id_ptr];
+            // Update the entry.
+            pos = std::max(pos, block_pos);
+          });
+
       if (inst.opcode() == SpvOpVectorShuffle) {
         // We might access the vector operands multiple times. Make sure they
         // are evaluated only once.
@@ -2896,6 +2918,27 @@
       }
     }
   }
+
+  // For an ID defined in this function, if it is used in a different construct
+  // than its definition, then it needs a named constant definition.  Otherwise
+  // we might sink an expensive computation into control flow, and hence change
+  // performance.
+  for (const auto& id_and_pos : id_def_pos) {
+    const auto id = id_and_pos.first;
+    const auto def_pos = id_and_pos.second;
+
+    auto last_use_where = id_to_last_use_pos.find(id);
+    if (last_use_where != id_to_last_use_pos.end()) {
+      const auto last_use_pos = last_use_where->second;
+      const auto* def_in_construct =
+          GetBlockInfo(block_order_[def_pos])->construct;
+      const auto* last_use_in_construct =
+          GetBlockInfo(block_order_[last_use_pos])->construct;
+      if (def_in_construct != last_use_in_construct) {
+        needs_named_const_def_.insert(id);
+      }
+    }
+  }
 }
 
 TypedExpression FunctionEmitter::MakeNumericConversion(
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index fbab63b..9a8ebad 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -112,7 +112,8 @@
   /// as its own continue target, and has branch to itself.
   bool is_single_block_loop = false;
 
-  /// The immediately enclosing structured construct.
+  /// The immediately enclosing structured construct. If this block is not
+  /// in the block order at all, then this is still nullptr.
   const Construct* construct = nullptr;
 
   /// Maps the ID of a successor block (in the CFG) to its edge classification.
diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc
index dd9f283..ca8e217 100644
--- a/src/reader/spirv/function_var_test.cc
+++ b/src/reader/spirv/function_var_test.cc
@@ -26,6 +26,7 @@
 namespace spirv {
 namespace {
 
+using ::testing::Eq;
 using ::testing::HasSubstr;
 
 /// @returns a SPIR-V assembly segment which assigns debug names
@@ -38,8 +39,11 @@
   return outs.str();
 }
 
-std::string CommonTypes() {
+std::string Preamble() {
   return R"(
+    OpCapability Shader
+    OpMemoryModel Logical Simple
+
     %void = OpTypeVoid
     %voidfn = OpTypeFunction %void
 
@@ -70,7 +74,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_AnonymousVars) {
-  auto* p = parser(test::Assemble(CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Preamble() + R"(
      %100 = OpFunction %void None %voidfn
      %entry = OpLabel
      %1 = OpVariable %ptr_uint Function
@@ -108,7 +112,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_NamedVars) {
-  auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + Preamble() + R"(
      %100 = OpFunction %void None %voidfn
      %entry = OpLabel
      %a = OpVariable %ptr_uint Function
@@ -146,7 +150,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_MixedTypes) {
-  auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + Preamble() + R"(
      %100 = OpFunction %void None %voidfn
      %entry = OpLabel
      %a = OpVariable %ptr_uint Function
@@ -184,8 +188,8 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_ScalarInitializers) {
-  auto* p = parser(
-      test::Assemble(Names({"a", "b", "c", "d", "e"}) + CommonTypes() + R"(
+  auto* p =
+      parser(test::Assemble(Names({"a", "b", "c", "d", "e"}) + Preamble() + R"(
      %100 = OpFunction %void None %voidfn
      %entry = OpLabel
      %a = OpVariable %ptr_bool Function %true
@@ -254,8 +258,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_ScalarNullInitializers) {
-  auto* p =
-      parser(test::Assemble(Names({"a", "b", "c", "d"}) + CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Names({"a", "b", "c", "d"}) + Preamble() + R"(
      %null_bool = OpConstantNull %bool
      %null_int = OpConstantNull %int
      %null_uint = OpConstantNull %uint
@@ -318,7 +321,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_VectorInitializer) {
-  auto* p = parser(test::Assemble(CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Preamble() + R"(
      %ptr = OpTypePointer Function %v2float
      %two = OpConstant %float 2.0
      %const = OpConstantComposite %v2float %float_1p5 %two
@@ -351,7 +354,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_MatrixInitializer) {
-  auto* p = parser(test::Assemble(CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Preamble() + R"(
      %ptr = OpTypePointer Function %m3v2float
      %two = OpConstant %float 2.0
      %three = OpConstant %float 3.0
@@ -402,7 +405,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer) {
-  auto* p = parser(test::Assemble(CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Preamble() + R"(
      %ptr = OpTypePointer Function %arr2uint
      %two = OpConstant %uint 2
      %const = OpConstantComposite %arr2uint %uint_1 %two
@@ -436,7 +439,7 @@
 
 TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_AliasType) {
   auto* p = parser(test::Assemble(
-      std::string("OpDecorate %arr2uint ArrayStride 16\n") + CommonTypes() + R"(
+      std::string("OpDecorate %arr2uint ArrayStride 16\n") + Preamble() + R"(
      %ptr = OpTypePointer Function %arr2uint
      %two = OpConstant %uint 2
      %const = OpConstantComposite %arr2uint %uint_1 %two
@@ -469,7 +472,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_Null) {
-  auto* p = parser(test::Assemble(CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Preamble() + R"(
      %ptr = OpTypePointer Function %arr2uint
      %two = OpConstant %uint 2
      %const = OpConstantNull %arr2uint
@@ -503,7 +506,7 @@
 
 TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_AliasType_Null) {
   auto* p = parser(test::Assemble(
-      std::string("OpDecorate %arr2uint ArrayStride 16\n") + CommonTypes() + R"(
+      std::string("OpDecorate %arr2uint ArrayStride 16\n") + Preamble() + R"(
      %ptr = OpTypePointer Function %arr2uint
      %two = OpConstant %uint 2
      %const = OpConstantNull %arr2uint
@@ -536,7 +539,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer) {
-  auto* p = parser(test::Assemble(CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Preamble() + R"(
      %ptr = OpTypePointer Function %strct
      %two = OpConstant %uint 2
      %arrconst = OpConstantComposite %arr2uint %uint_1 %two
@@ -575,7 +578,7 @@
 }
 
 TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer_Null) {
-  auto* p = parser(test::Assemble(CommonTypes() + R"(
+  auto* p = parser(test::Assemble(Preamble() + R"(
      %ptr = OpTypePointer Function %strct
      %two = OpConstant %uint 2
      %arrconst = OpConstantComposite %arr2uint %uint_1 %two
@@ -613,6 +616,184 @@
 )")) << ToString(fe.ast_body());
 }
 
+TEST_F(SpvParserTest,
+       EmitStatement_CombinatorialValue_Defer_UsedOnceSameConstruct) {
+  auto assembly = Preamble() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     %25 = OpVariable %ptr_uint Function
+     %2 = OpIAdd %uint %uint_1 %uint_1
+     OpStore %25 %uint_1 ; Do initial store to mark source location
+     OpBranch %20
+
+     %20 = OpLabel
+     OpStore %25 %2 ; defer emission of the addition until here.
+     OpReturn
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{
+  Variable{
+    x_25
+    function
+    __u32
+  }
+}
+Assignment{
+  Identifier{x_25}
+  ScalarConstructor{1}
+}
+Assignment{
+  Identifier{x_25}
+  Binary{
+    ScalarConstructor{1}
+    add
+    ScalarConstructor{1}
+  }
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitStatement_CombinatorialValue_Immediate_UsedTwice) {
+  auto assembly = Preamble() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     %25 = OpVariable %ptr_uint Function
+     %2 = OpIAdd %uint %uint_1 %uint_1
+     OpStore %25 %uint_1 ; Do initial store to mark source location
+     OpBranch %20
+
+     %20 = OpLabel
+     OpStore %25 %2
+     OpStore %25 %2
+     OpReturn
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{
+  Variable{
+    x_25
+    function
+    __u32
+  }
+}
+VariableDeclStatement{
+  Variable{
+    x_2
+    none
+    __u32
+    {
+      Binary{
+        ScalarConstructor{1}
+        add
+        ScalarConstructor{1}
+      }
+    }
+  }
+}
+Assignment{
+  Identifier{x_25}
+  ScalarConstructor{1}
+}
+Assignment{
+  Identifier{x_25}
+  Identifier{x_2}
+}
+Assignment{
+  Identifier{x_25}
+  Identifier{x_2}
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest,
+       EmitStatement_CombinatorialValue_Immediate_UsedOnceDifferentConstruct) {
+  // Translation should not sink expensive operations into or out of control
+  // flow. As a simple heuristic, don't move *any* combinatorial operation
+  // across any constrol flow.
+  auto assembly = Preamble() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     %25 = OpVariable %ptr_uint Function
+     %2 = OpIAdd %uint %uint_1 %uint_1
+     OpStore %25 %uint_1 ; Do initial store to mark source location
+     OpBranch %20
+
+     %20 = OpLabel  ; Introduce a new construct
+     OpLoopMerge %99 %80 None
+     OpBranch %80
+
+     %80 = OpLabel
+     OpStore %25 %2  ; store combinatorial value %2, inside the loop
+     OpBranch %20
+
+     %99 = OpLabel ; merge block
+     OpStore %25 %uint_2
+     OpReturn
+
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{
+  Variable{
+    x_25
+    function
+    __u32
+  }
+}
+VariableDeclStatement{
+  Variable{
+    x_2
+    none
+    __u32
+    {
+      Binary{
+        ScalarConstructor{1}
+        add
+        ScalarConstructor{1}
+      }
+    }
+  }
+}
+Assignment{
+  Identifier{x_25}
+  ScalarConstructor{1}
+}
+Loop{
+  continuing {
+    Assignment{
+      Identifier{x_25}
+      Identifier{x_2}
+    }
+  }
+}
+Assignment{
+  Identifier{x_25}
+  ScalarConstructor{2}
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader