[spirv-reader] Improve placement of hoisted vars
When we hoist a variable out of a continue construct, put it
in associated loop construct, if it exists. This reduces its
lifetime in WGSL, and easier to understand as a code reader.
Change-Id: I8f0cc37640bfe67874cbc27b55029e79e9a8992c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24321
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/construct.h b/src/reader/spirv/construct.h
index 767d5de..5e6f754 100644
--- a/src/reader/spirv/construct.h
+++ b/src/reader/spirv/construct.h
@@ -218,6 +218,13 @@
return ss.str();
}
+/// Converts a construct to a string.
+/// @param c the construct
+/// @returns the string representation
+inline std::string ToString(const Construct* c) {
+ return c ? ToString(*c) : ToStringBrief(c);
+}
+
/// Converts a unique pointer to a construct to a string.
/// @param c the construct
/// @returns the string representation
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index dad106c..df4dc52 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -1145,6 +1145,30 @@
return nullptr;
}
+const Construct* FunctionEmitter::SiblingLoopConstruct(
+ const Construct* c) const {
+ if (c == nullptr || c->kind != Construct::kContinue) {
+ return nullptr;
+ }
+ const uint32_t continue_target_id = c->begin_id;
+ const auto* continue_target = GetBlockInfo(continue_target_id);
+ const uint32_t header_id = continue_target->header_for_continue;
+ if (continue_target_id == header_id) {
+ // The continue target is the whole loop.
+ return nullptr;
+ }
+ const auto* candidate = GetBlockInfo(header_id)->construct;
+ // Walk up the construct tree until we hit the loop. In future
+ // we might handle the corner case where the same block is both a
+ // loop header and a selection header. For example, where the
+ // loop header block has a conditional branch going to distinct
+ // targets inside the loop body.
+ while (candidate && candidate->kind != Construct::kLoop) {
+ candidate = candidate->parent;
+ }
+ return candidate;
+}
+
bool FunctionEmitter::ClassifyCFGEdges() {
if (failed()) {
return false;
@@ -3013,13 +3037,17 @@
const auto block_pos = block_info->pos;
for (const auto& inst : *(block_info->basic_block)) {
// Update the usage span for IDs used by this instruction.
- inst.ForEachInId([this, block_pos](const uint32_t* id_ptr) {
- auto* def_info = GetDefInfo(*id_ptr);
- if (def_info) {
- def_info->num_uses++;
- def_info->last_use_pos = std::max(def_info->last_use_pos, block_pos);
- }
- });
+ // But skip uses in OpPhi because they are handled differently.
+ if (inst.opcode() != SpvOpPhi) {
+ inst.ForEachInId([this, block_pos](const uint32_t* id_ptr) {
+ auto* def_info = GetDefInfo(*id_ptr);
+ if (def_info) {
+ def_info->num_uses++;
+ def_info->last_use_pos =
+ std::max(def_info->last_use_pos, block_pos);
+ }
+ });
+ }
if (inst.opcode() == SpvOpPhi) {
// Declare a name for the variable used to carry values to a phi.
@@ -3051,7 +3079,7 @@
// Schedule the declaration of the state variable.
const auto* enclosing_construct =
- GetSmallestEnclosingConstruct(first_pos, last_pos);
+ GetEnclosingScope(first_pos, last_pos);
GetBlockInfo(enclosing_construct->begin_id)
->phis_needing_state_vars.push_back(phi_id);
}
@@ -3088,7 +3116,7 @@
if (def_in_construct != construct_with_last_use) {
const auto* enclosing_construct =
- GetSmallestEnclosingConstruct(first_pos, last_use_pos);
+ GetEnclosingScope(first_pos, last_use_pos);
if (enclosing_construct == def_in_construct) {
// We can use a plain 'const' definition.
def_info->requires_named_const_def = true;
@@ -3103,15 +3131,19 @@
}
}
-const Construct* FunctionEmitter::GetSmallestEnclosingConstruct(
- uint32_t first_pos,
- uint32_t last_pos) const {
+const Construct* FunctionEmitter::GetEnclosingScope(uint32_t first_pos,
+ uint32_t last_pos) const {
const auto* enclosing_construct =
GetBlockInfo(block_order_[first_pos])->construct;
assert(enclosing_construct != nullptr);
// Constructs are strictly nesting, so follow parent pointers
while (enclosing_construct && !enclosing_construct->ContainsPos(last_pos)) {
- enclosing_construct = enclosing_construct->parent;
+ // The scope of a continue construct is enclosed in its associated loop
+ // construct, but they are siblings in our construct tree.
+ const auto* sibling_loop = SiblingLoopConstruct(enclosing_construct);
+ // Go to the sibling loop if it exists, otherwise walk up to the parent.
+ enclosing_construct =
+ sibling_loop ? sibling_loop : enclosing_construct->parent;
}
// At worst, we go all the way out to the function construct.
assert(enclosing_construct != nullptr);
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 7044b0f..606bcde 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -601,14 +601,23 @@
return where->second.get();
}
- /// Returns the most deeply nested structured construct which encloses
- /// both block positions. Each position must be a valid index into the
- /// function block order array.
+ /// Returns the most deeply nested structured construct which encloses the
+ /// WGSL scopes of names declared in both block positions. Each position must
+ /// be a valid index into the function block order array.
/// @param first_pos the first block position
/// @param last_pos the last block position
/// @returns the smallest construct containing both positions
- const Construct* GetSmallestEnclosingConstruct(uint32_t first_pos,
- uint32_t last_pos) const;
+ const Construct* GetEnclosingScope(uint32_t first_pos,
+ uint32_t last_pos) const;
+
+ /// Finds loop construct associated with a continue construct, if it exists.
+ /// Returns nullptr if:
+ /// - the given construct is not a continue construct
+ /// - the continue construct does not have an associated loop construct
+ /// (the continue target is also the loop header block)
+ /// @param c the continue construct
+ /// @returns the associated loop construct, or nullptr
+ const Construct* SiblingLoopConstruct(const Construct* c) const;
private:
/// @returns the store type for the OpVariable instruction, or
diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc
index d245427..188b3b9 100644
--- a/src/reader/spirv/function_cfg_test.cc
+++ b/src/reader/spirv/function_cfg_test.cc
@@ -17,6 +17,7 @@
#include <vector>
#include "gmock/gmock.h"
+#include "src/reader/spirv/construct.h"
#include "src/reader/spirv/function.h"
#include "src/reader/spirv/parser_impl.h"
#include "src/reader/spirv/parser_impl_test_helper.h"
@@ -87,15 +88,22 @@
)";
}
-/// Runs the necessary flow until and including finding switch case
-/// headers.
-/// @returns the result of finding switch case headers.
-bool FlowFindSwitchCaseHeaders(FunctionEmitter* fe) {
+/// Runs the necessary flow until and including labeling control
+/// flow constructs.
+/// @returns the result of labeling control flow constructs.
+bool FlowLabelControlFlowConstructs(FunctionEmitter* fe) {
fe->RegisterBasicBlocks();
EXPECT_TRUE(fe->RegisterMerges()) << fe->parser()->error();
fe->ComputeBlockOrderAndPositions();
EXPECT_TRUE(fe->VerifyHeaderContinueMergeOrder()) << fe->parser()->error();
- EXPECT_TRUE(fe->LabelControlFlowConstructs()) << fe->parser()->error();
+ return fe->LabelControlFlowConstructs();
+}
+
+/// Runs the necessary flow until and including finding switch case
+/// headers.
+/// @returns the result of finding switch case headers.
+bool FlowFindSwitchCaseHeaders(FunctionEmitter* fe) {
+ EXPECT_TRUE(FlowLabelControlFlowConstructs(fe)) << fe->parser()->error();
return fe->FindSwitchCaseHeaders();
}
@@ -13418,6 +13426,121 @@
<< ToString(fe.ast_body());
}
+TEST_F(SpvParserTest, SiblingLoopConstruct_Null) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_EQ(fe.SiblingLoopConstruct(nullptr), nullptr);
+}
+
+TEST_F(SpvParserTest, SiblingLoopConstruct_NotAContinue) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ ASSERT_TRUE(FlowLabelControlFlowConstructs(&fe)) << p->error();
+ const Construct* c = fe.GetBlockInfo(10)->construct;
+ EXPECT_NE(c, nullptr);
+ EXPECT_EQ(fe.SiblingLoopConstruct(c), nullptr);
+}
+
+TEST_F(SpvParserTest, SiblingLoopConstruct_SingleBlockLoop) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpBranch %20
+
+ %20 = OpLabel
+ OpLoopMerge %99 %20 None
+ OpBranchConditional %cond %20 %99
+
+ %99 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ ASSERT_TRUE(FlowLabelControlFlowConstructs(&fe)) << p->error();
+ const Construct* c = fe.GetBlockInfo(20)->construct;
+ EXPECT_EQ(c->kind, Construct::kContinue);
+ EXPECT_EQ(fe.SiblingLoopConstruct(c), nullptr);
+}
+
+TEST_F(SpvParserTest, SiblingLoopConstruct_ContinueIsWholeMultiBlockLoop) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpBranch %20
+
+ %20 = OpLabel
+ OpLoopMerge %99 %20 None ; continue target is also loop header
+ OpBranchConditional %cond %30 %99
+
+ %30 = OpLabel
+ OpBranch %20
+
+ %99 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << p->error() << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ ASSERT_TRUE(FlowLabelControlFlowConstructs(&fe)) << p->error();
+ const Construct* c = fe.GetBlockInfo(20)->construct;
+ EXPECT_EQ(c->kind, Construct::kContinue);
+ EXPECT_EQ(fe.SiblingLoopConstruct(c), nullptr);
+}
+
+TEST_F(SpvParserTest, SiblingLoopConstruct_HasSiblingLoop) {
+ auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpBranch %20
+
+ %20 = OpLabel
+ OpLoopMerge %99 %30 None
+ OpBranchConditional %cond %30 %99
+
+ %30 = OpLabel ; continue target
+ OpBranch %20
+
+ %99 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ ASSERT_TRUE(FlowLabelControlFlowConstructs(&fe)) << p->error();
+ const Construct* c = fe.GetBlockInfo(30)->construct;
+ EXPECT_EQ(c->kind, Construct::kContinue);
+ EXPECT_THAT(ToString(fe.SiblingLoopConstruct(c)),
+ Eq("Construct{ Loop [1,2) begin_id:20 end_id:30 depth:1 "
+ "parent:Function@10 in-l:Loop@20 }"));
+}
+
} // namespace
} // namespace spirv
} // namespace reader
diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc
index 6693ae7..c591175 100644
--- a/src/reader/spirv/function_var_test.cc
+++ b/src/reader/spirv/function_var_test.cc
@@ -1042,13 +1042,23 @@
}
}
}
+ VariableDeclStatement{
+ Variable{
+ x_4
+ none
+ __u32
+ {
+ Binary{
+ Identifier{x_2}
+ add
+ ScalarConstructor{1}
+ }
+ }
+ }
+ }
Assignment{
Identifier{x_2_phi}
- Binary{
- Identifier{x_2}
- add
- ScalarConstructor{1}
- }
+ Identifier{x_4}
}
Assignment{
Identifier{x_3_phi}
@@ -1122,13 +1132,6 @@
}
VariableDeclStatement{
Variable{
- x_4
- function
- __u32
- }
- }
- VariableDeclStatement{
- Variable{
x_2_phi
function
__u32
@@ -1201,12 +1204,18 @@
}
}
continuing {
- Assignment{
- Identifier{x_4}
- Binary{
- Identifier{x_2}
- add
- ScalarConstructor{1}
+ VariableDeclStatement{
+ Variable{
+ x_4
+ none
+ __u32
+ {
+ Binary{
+ Identifier{x_2}
+ add
+ ScalarConstructor{1}
+ }
+ }
}
}
Assignment{
@@ -1224,6 +1233,174 @@
)")) << ToString(fe.ast_body());
}
+TEST_F(SpvParserTest, EmitStatement_Phi_ValueFromLoopBodyAndContinuing) {
+ auto assembly = Preamble() + R"(
+ %pty = OpTypePointer Private %uint
+ %1 = OpVariable %pty Private
+ %boolpty = OpTypePointer Private %bool
+ %17 = OpVariable %boolpty Private
+
+ %100 = OpFunction %void None %voidfn
+
+ %9 = OpLabel
+ %101 = OpLoad %bool %17
+ OpBranch %10
+
+ ; Use an outer loop to show we put the new variable in the
+ ; smallest enclosing scope.
+ %10 = OpLabel
+ OpLoopMerge %99 %89 None
+ OpBranch %20
+
+ %20 = OpLabel
+ %2 = OpPhi %uint %uint_0 %10 %4 %30 ; gets computed value
+ %5 = OpPhi %uint %uint_1 %10 %7 %30
+ %4 = OpIAdd %uint %2 %uint_1 ; define %4
+ %6 = OpIAdd %uint %4 %uint_1 ; use %4
+ OpLoopMerge %89 %30 None
+ OpBranchConditional %101 %89 %30
+
+ %30 = OpLabel
+ %7 = OpIAdd %uint %4 %6 ; use %4 again
+ OpBranch %20
+
+ %89 = OpLabel
+ OpBranch %10
+
+ %99 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+ << assembly << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{
+ Variable{
+ x_101
+ none
+ __bool
+ {
+ Identifier{x_17}
+ }
+ }
+}
+Loop{
+ VariableDeclStatement{
+ Variable{
+ x_4
+ function
+ __u32
+ }
+ }
+ VariableDeclStatement{
+ Variable{
+ x_6
+ function
+ __u32
+ }
+ }
+ VariableDeclStatement{
+ Variable{
+ x_2_phi
+ function
+ __u32
+ }
+ }
+ VariableDeclStatement{
+ Variable{
+ x_5_phi
+ function
+ __u32
+ }
+ }
+ Assignment{
+ Identifier{x_2_phi}
+ ScalarConstructor{0}
+ }
+ Assignment{
+ Identifier{x_5_phi}
+ ScalarConstructor{1}
+ }
+ Loop{
+ VariableDeclStatement{
+ Variable{
+ x_2
+ none
+ __u32
+ {
+ Identifier{x_2_phi}
+ }
+ }
+ }
+ VariableDeclStatement{
+ Variable{
+ x_5
+ none
+ __u32
+ {
+ Identifier{x_5_phi}
+ }
+ }
+ }
+ Assignment{
+ Identifier{x_4}
+ Binary{
+ Identifier{x_2}
+ add
+ ScalarConstructor{1}
+ }
+ }
+ Assignment{
+ Identifier{x_6}
+ Binary{
+ Identifier{x_4}
+ add
+ ScalarConstructor{1}
+ }
+ }
+ If{
+ (
+ Identifier{x_101}
+ )
+ {
+ Break{}
+ }
+ }
+ continuing {
+ VariableDeclStatement{
+ Variable{
+ x_7
+ none
+ __u32
+ {
+ Binary{
+ Identifier{x_4}
+ add
+ Identifier{x_6}
+ }
+ }
+ }
+ }
+ Assignment{
+ Identifier{x_2_phi}
+ Identifier{x_4}
+ }
+ Assignment{
+ Identifier{x_5_phi}
+ Identifier{x_7}
+ }
+ }
+ }
+}
+Return{}
+)")) << ToString(fe.ast_body())
+ << assembly;
+}
+
TEST_F(SpvParserTest, EmitStatement_Phi_FromElseAndThen) {
auto assembly = Preamble() + R"(
%pty = OpTypePointer Private %uint