[spirv-reader] Don't move combinatorial values across control flow
Avoid sinking expensive operations into control flow such as loops.
The heuristic way to achieve that is to avoid moving combinatorial
values across *any* structured construct boundaries.
Bug: tint:3
Change-Id: I91502b01166a0db64c0e652331591850df75f9d4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24140
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 840836b..ad68e73 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2877,8 +2877,30 @@
}
void FunctionEmitter::RegisterValuesNeedingNamedDefinition() {
- for (auto& block : function_) {
- for (const auto& inst : block) {
+ // Maps a result ID to the block position where it is last used.
+ std::unordered_map<uint32_t, uint32_t> id_to_last_use_pos;
+ // List of pairs of (result id, block position of the definition).
+ std::vector<std::pair<uint32_t, uint32_t>> id_def_pos;
+
+ for (auto block_id : block_order_) {
+ const auto* block_info = GetBlockInfo(block_id);
+ const auto block_pos = block_info->pos;
+
+ for (const auto& inst : *(block_info->basic_block)) {
+ const auto result_id = inst.result_id();
+ if (result_id != 0) {
+ id_def_pos.emplace_back(
+ std::pair<uint32_t, uint32_t>{result_id, block_pos});
+ }
+ inst.ForEachInId(
+ [&id_to_last_use_pos, block_pos](const uint32_t* id_ptr) {
+ // If the id is not in the map already, this will create
+ // an entry with value 0.
+ auto& pos = id_to_last_use_pos[*id_ptr];
+ // Update the entry.
+ pos = std::max(pos, block_pos);
+ });
+
if (inst.opcode() == SpvOpVectorShuffle) {
// We might access the vector operands multiple times. Make sure they
// are evaluated only once.
@@ -2896,6 +2918,27 @@
}
}
}
+
+ // For an ID defined in this function, if it is used in a different construct
+ // than its definition, then it needs a named constant definition. Otherwise
+ // we might sink an expensive computation into control flow, and hence change
+ // performance.
+ for (const auto& id_and_pos : id_def_pos) {
+ const auto id = id_and_pos.first;
+ const auto def_pos = id_and_pos.second;
+
+ auto last_use_where = id_to_last_use_pos.find(id);
+ if (last_use_where != id_to_last_use_pos.end()) {
+ const auto last_use_pos = last_use_where->second;
+ const auto* def_in_construct =
+ GetBlockInfo(block_order_[def_pos])->construct;
+ const auto* last_use_in_construct =
+ GetBlockInfo(block_order_[last_use_pos])->construct;
+ if (def_in_construct != last_use_in_construct) {
+ needs_named_const_def_.insert(id);
+ }
+ }
+ }
}
TypedExpression FunctionEmitter::MakeNumericConversion(
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index fbab63b..9a8ebad 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -112,7 +112,8 @@
/// as its own continue target, and has branch to itself.
bool is_single_block_loop = false;
- /// The immediately enclosing structured construct.
+ /// The immediately enclosing structured construct. If this block is not
+ /// in the block order at all, then this is still nullptr.
const Construct* construct = nullptr;
/// Maps the ID of a successor block (in the CFG) to its edge classification.
diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc
index dd9f283..ca8e217 100644
--- a/src/reader/spirv/function_var_test.cc
+++ b/src/reader/spirv/function_var_test.cc
@@ -26,6 +26,7 @@
namespace spirv {
namespace {
+using ::testing::Eq;
using ::testing::HasSubstr;
/// @returns a SPIR-V assembly segment which assigns debug names
@@ -38,8 +39,11 @@
return outs.str();
}
-std::string CommonTypes() {
+std::string Preamble() {
return R"(
+ OpCapability Shader
+ OpMemoryModel Logical Simple
+
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
@@ -70,7 +74,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_AnonymousVars) {
- auto* p = parser(test::Assemble(CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Preamble() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpVariable %ptr_uint Function
@@ -108,7 +112,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_NamedVars) {
- auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + Preamble() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%a = OpVariable %ptr_uint Function
@@ -146,7 +150,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_MixedTypes) {
- auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + Preamble() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%a = OpVariable %ptr_uint Function
@@ -184,8 +188,8 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_ScalarInitializers) {
- auto* p = parser(
- test::Assemble(Names({"a", "b", "c", "d", "e"}) + CommonTypes() + R"(
+ auto* p =
+ parser(test::Assemble(Names({"a", "b", "c", "d", "e"}) + Preamble() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%a = OpVariable %ptr_bool Function %true
@@ -254,8 +258,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_ScalarNullInitializers) {
- auto* p =
- parser(test::Assemble(Names({"a", "b", "c", "d"}) + CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Names({"a", "b", "c", "d"}) + Preamble() + R"(
%null_bool = OpConstantNull %bool
%null_int = OpConstantNull %int
%null_uint = OpConstantNull %uint
@@ -318,7 +321,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_VectorInitializer) {
- auto* p = parser(test::Assemble(CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Preamble() + R"(
%ptr = OpTypePointer Function %v2float
%two = OpConstant %float 2.0
%const = OpConstantComposite %v2float %float_1p5 %two
@@ -351,7 +354,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_MatrixInitializer) {
- auto* p = parser(test::Assemble(CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Preamble() + R"(
%ptr = OpTypePointer Function %m3v2float
%two = OpConstant %float 2.0
%three = OpConstant %float 3.0
@@ -402,7 +405,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer) {
- auto* p = parser(test::Assemble(CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Preamble() + R"(
%ptr = OpTypePointer Function %arr2uint
%two = OpConstant %uint 2
%const = OpConstantComposite %arr2uint %uint_1 %two
@@ -436,7 +439,7 @@
TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_AliasType) {
auto* p = parser(test::Assemble(
- std::string("OpDecorate %arr2uint ArrayStride 16\n") + CommonTypes() + R"(
+ std::string("OpDecorate %arr2uint ArrayStride 16\n") + Preamble() + R"(
%ptr = OpTypePointer Function %arr2uint
%two = OpConstant %uint 2
%const = OpConstantComposite %arr2uint %uint_1 %two
@@ -469,7 +472,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_Null) {
- auto* p = parser(test::Assemble(CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Preamble() + R"(
%ptr = OpTypePointer Function %arr2uint
%two = OpConstant %uint 2
%const = OpConstantNull %arr2uint
@@ -503,7 +506,7 @@
TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_AliasType_Null) {
auto* p = parser(test::Assemble(
- std::string("OpDecorate %arr2uint ArrayStride 16\n") + CommonTypes() + R"(
+ std::string("OpDecorate %arr2uint ArrayStride 16\n") + Preamble() + R"(
%ptr = OpTypePointer Function %arr2uint
%two = OpConstant %uint 2
%const = OpConstantNull %arr2uint
@@ -536,7 +539,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer) {
- auto* p = parser(test::Assemble(CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Preamble() + R"(
%ptr = OpTypePointer Function %strct
%two = OpConstant %uint 2
%arrconst = OpConstantComposite %arr2uint %uint_1 %two
@@ -575,7 +578,7 @@
}
TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer_Null) {
- auto* p = parser(test::Assemble(CommonTypes() + R"(
+ auto* p = parser(test::Assemble(Preamble() + R"(
%ptr = OpTypePointer Function %strct
%two = OpConstant %uint 2
%arrconst = OpConstantComposite %arr2uint %uint_1 %two
@@ -613,6 +616,184 @@
)")) << ToString(fe.ast_body());
}
+TEST_F(SpvParserTest,
+ EmitStatement_CombinatorialValue_Defer_UsedOnceSameConstruct) {
+ auto assembly = Preamble() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ %25 = OpVariable %ptr_uint Function
+ %2 = OpIAdd %uint %uint_1 %uint_1
+ OpStore %25 %uint_1 ; Do initial store to mark source location
+ OpBranch %20
+
+ %20 = OpLabel
+ OpStore %25 %2 ; defer emission of the addition until here.
+ OpReturn
+
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{
+ Variable{
+ x_25
+ function
+ __u32
+ }
+}
+Assignment{
+ Identifier{x_25}
+ ScalarConstructor{1}
+}
+Assignment{
+ Identifier{x_25}
+ Binary{
+ ScalarConstructor{1}
+ add
+ ScalarConstructor{1}
+ }
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitStatement_CombinatorialValue_Immediate_UsedTwice) {
+ auto assembly = Preamble() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ %25 = OpVariable %ptr_uint Function
+ %2 = OpIAdd %uint %uint_1 %uint_1
+ OpStore %25 %uint_1 ; Do initial store to mark source location
+ OpBranch %20
+
+ %20 = OpLabel
+ OpStore %25 %2
+ OpStore %25 %2
+ OpReturn
+
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{
+ Variable{
+ x_25
+ function
+ __u32
+ }
+}
+VariableDeclStatement{
+ Variable{
+ x_2
+ none
+ __u32
+ {
+ Binary{
+ ScalarConstructor{1}
+ add
+ ScalarConstructor{1}
+ }
+ }
+ }
+}
+Assignment{
+ Identifier{x_25}
+ ScalarConstructor{1}
+}
+Assignment{
+ Identifier{x_25}
+ Identifier{x_2}
+}
+Assignment{
+ Identifier{x_25}
+ Identifier{x_2}
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest,
+ EmitStatement_CombinatorialValue_Immediate_UsedOnceDifferentConstruct) {
+ // Translation should not sink expensive operations into or out of control
+ // flow. As a simple heuristic, don't move *any* combinatorial operation
+ // across any constrol flow.
+ auto assembly = Preamble() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ %25 = OpVariable %ptr_uint Function
+ %2 = OpIAdd %uint %uint_1 %uint_1
+ OpStore %25 %uint_1 ; Do initial store to mark source location
+ OpBranch %20
+
+ %20 = OpLabel ; Introduce a new construct
+ OpLoopMerge %99 %80 None
+ OpBranch %80
+
+ %80 = OpLabel
+ OpStore %25 %2 ; store combinatorial value %2, inside the loop
+ OpBranch %20
+
+ %99 = OpLabel ; merge block
+ OpStore %25 %uint_2
+ OpReturn
+
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{
+ Variable{
+ x_25
+ function
+ __u32
+ }
+}
+VariableDeclStatement{
+ Variable{
+ x_2
+ none
+ __u32
+ {
+ Binary{
+ ScalarConstructor{1}
+ add
+ ScalarConstructor{1}
+ }
+ }
+ }
+}
+Assignment{
+ Identifier{x_25}
+ ScalarConstructor{1}
+}
+Loop{
+ continuing {
+ Assignment{
+ Identifier{x_25}
+ Identifier{x_2}
+ }
+ }
+}
+Assignment{
+ Identifier{x_25}
+ ScalarConstructor{2}
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
} // namespace
} // namespace spirv
} // namespace reader