spirv-reader: register statically accessed inputs and outputs
Bug: tint:508
Change-Id: I585abb0791f5ea0bcb282f12f6940e718da4956d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48861
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: David Neto <dneto@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 941e1e7..0cdaac4 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -581,6 +581,7 @@
sources = [
"reader/spirv/construct.cc",
"reader/spirv/construct.h",
+ "reader/spirv/entry_point_info.cc",
"reader/spirv/entry_point_info.h",
"reader/spirv/enum_converter.cc",
"reader/spirv/enum_converter.h",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 795091f..658f7dd 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -352,6 +352,7 @@
reader/spirv/construct.h
reader/spirv/construct.cc
reader/spirv/entry_point_info.h
+ reader/spirv/entry_point_info.cc
reader/spirv/enum_converter.h
reader/spirv/enum_converter.cc
reader/spirv/fail_stream.h
diff --git a/src/reader/spirv/entry_point_info.cc b/src/reader/spirv/entry_point_info.cc
new file mode 100644
index 0000000..61f1586
--- /dev/null
+++ b/src/reader/spirv/entry_point_info.cc
@@ -0,0 +1,38 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/reader/spirv/entry_point_info.h"
+
+#include <utility>
+
+namespace tint {
+namespace reader {
+namespace spirv {
+
+EntryPointInfo::EntryPointInfo(std::string the_name,
+ ast::PipelineStage the_stage,
+ std::vector<uint32_t>&& the_inputs,
+ std::vector<uint32_t>&& the_outputs)
+ : name(the_name),
+ stage(the_stage),
+ inputs(std::move(the_inputs)),
+ outputs(std::move(the_outputs)) {}
+
+EntryPointInfo::EntryPointInfo(const EntryPointInfo&) = default;
+
+EntryPointInfo::~EntryPointInfo() = default;
+
+} // namespace spirv
+} // namespace reader
+} // namespace tint
diff --git a/src/reader/spirv/entry_point_info.h b/src/reader/spirv/entry_point_info.h
index 8256794..8cb11f3 100644
--- a/src/reader/spirv/entry_point_info.h
+++ b/src/reader/spirv/entry_point_info.h
@@ -16,6 +16,7 @@
#define SRC_READER_SPIRV_ENTRY_POINT_INFO_H_
#include <string>
+#include <vector>
#include "src/ast/pipeline_stage.h"
@@ -25,10 +26,29 @@
/// Entry point information for a function
struct EntryPointInfo {
+ // Constructor.
+ // @param the_name the name of the entry point
+ // @param the_stage the pipeline stage
+ // @param the_inputs list of IDs for Input variables used by the shader
+ // @param the_outputs list of IDs for Output variables used by the shader
+ EntryPointInfo(std::string the_name,
+ ast::PipelineStage the_stage,
+ std::vector<uint32_t>&& the_inputs,
+ std::vector<uint32_t>&& the_outputs);
+ // Copy constructor
+ // @param other the other entry point info to be built from
+ EntryPointInfo(const EntryPointInfo& other);
+ // Destructor
+ ~EntryPointInfo();
+
/// The entry point name
std::string name;
/// The entry point stage
ast::PipelineStage stage = ast::PipelineStage::kNone;
+ /// IDs of pipeline input variables, sorted and without duplicates.
+ std::vector<uint32_t> inputs;
+ /// IDs of pipeline output variables, sorted and without duplicates.
+ std::vector<uint32_t> outputs;
};
} // namespace spirv
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index f147b70..32b32b0 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -14,8 +14,10 @@
#include "src/reader/spirv/parser_impl.h"
+#include <algorithm>
#include <limits>
#include <locale>
+#include <utility>
#include "source/opt/build_module.h"
#include "src/ast/bitcast_expression.h"
@@ -26,6 +28,7 @@
#include "src/sem/depth_texture_type.h"
#include "src/sem/multisampled_texture_type.h"
#include "src/sem/sampled_texture_type.h"
+#include "src/utils/unique_vector.h"
namespace tint {
namespace reader {
@@ -711,8 +714,32 @@
const uint32_t function_id = entry_point.GetSingleWordInOperand(1);
const std::string ep_name = entry_point.GetOperand(2).AsString();
- EntryPointInfo info{ep_name, enum_converter_.ToPipelineStage(stage)};
- function_to_ep_info_[function_id].push_back(info);
+ tint::UniqueVector<uint32_t> inputs;
+ tint::UniqueVector<uint32_t> outputs;
+ for (unsigned iarg = 3; iarg < entry_point.NumInOperands(); iarg++) {
+ const uint32_t var_id = entry_point.GetSingleWordInOperand(iarg);
+ if (const auto* var_inst = def_use_mgr_->GetDef(var_id)) {
+ switch (SpvStorageClass(var_inst->GetSingleWordInOperand(0))) {
+ case SpvStorageClassInput:
+ inputs.add(var_id);
+ break;
+ case SpvStorageClassOutput:
+ outputs.add(var_id);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ // Save the lists, in ID-sorted order.
+ std::vector<uint32_t> sorted_inputs(inputs.begin(), inputs.end());
+ std::sort(sorted_inputs.begin(), sorted_inputs.end());
+ std::vector<uint32_t> sorted_outputs(outputs.begin(), outputs.end());
+ std::sort(sorted_inputs.begin(), sorted_inputs.end());
+
+ function_to_ep_info_[function_id].emplace_back(
+ ep_name, enum_converter_.ToPipelineStage(stage),
+ std::move(sorted_inputs), std::move(sorted_outputs));
}
// The enum conversion could have failed, so return the existing status value.
return success_;
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index b52daa2..714d8dc 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -24,6 +24,7 @@
using SpvModuleScopeVarParserTest = SpvParserTest;
+using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::Not;
@@ -3722,6 +3723,104 @@
})")) << module_str;
}
+TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
+ const std::string assembly =
+ R"(
+ OpCapability Shader
+ OpMemoryModel Logical Simple
+ OpEntryPoint GLCompute %1000 "w1000"
+ OpEntryPoint GLCompute %1100 "w1100" %1
+ OpEntryPoint GLCompute %1200 "w1200" %2 %15
+ ; duplication is tolerated prior to SPIR-V 1.4
+ OpEntryPoint GLCompute %1300 "w1300" %1 %15 %2 %1
+
+)" + CommonTypes() +
+ R"(
+
+ %ptr_in_uint = OpTypePointer Input %uint
+ %ptr_out_uint = OpTypePointer Output %uint
+
+ %1 = OpVariable %ptr_in_uint Input
+ %2 = OpVariable %ptr_in_uint Input
+ %5 = OpVariable %ptr_in_uint Input
+ %11 = OpVariable %ptr_out_uint Output
+ %12 = OpVariable %ptr_out_uint Output
+ %15 = OpVariable %ptr_out_uint Output
+
+ %100 = OpFunction %void None %voidfn
+ %entry_100 = OpLabel
+ %load_100 = OpLoad %uint %1
+ OpReturn
+ OpFunctionEnd
+
+ %200 = OpFunction %void None %voidfn
+ %entry_200 = OpLabel
+ %load_200 = OpLoad %uint %2
+ OpStore %15 %load_200
+ OpStore %15 %load_200
+ OpReturn
+ OpFunctionEnd
+
+ %300 = OpFunction %void None %voidfn
+ %entry_300 = OpLabel
+ %dummy_300_1 = OpFunctionCall %void %100
+ %dummy_300_2 = OpFunctionCall %void %200
+ OpReturn
+ OpFunctionEnd
+
+ ; Call nothing
+ %1000 = OpFunction %void None %voidfn
+ %entry_1000 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ ; Call %100
+ %1100 = OpFunction %void None %voidfn
+ %entry_1100 = OpLabel
+ %dummy_1100_1 = OpFunctionCall %void %100
+ OpReturn
+ OpFunctionEnd
+
+ ; Call %200
+ %1200 = OpFunction %void None %voidfn
+ %entry_1200 = OpLabel
+ %dummy_1200_1 = OpFunctionCall %void %200
+ OpReturn
+ OpFunctionEnd
+
+ ; Call %300
+ %1300 = OpFunction %void None %voidfn
+ %entry_1300 = OpLabel
+ %dummy_1300_1 = OpFunctionCall %void %300
+ OpReturn
+ OpFunctionEnd
+
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+ EXPECT_TRUE(p->error().empty());
+
+ const auto& info_1000 = p->GetEntryPointInfo(1000);
+ EXPECT_EQ(1u, info_1000.size());
+ EXPECT_TRUE(info_1000[0].inputs.empty());
+ EXPECT_TRUE(info_1000[0].outputs.empty());
+
+ const auto& info_1100 = p->GetEntryPointInfo(1100);
+ EXPECT_EQ(1u, info_1100.size());
+ EXPECT_THAT(info_1100[0].inputs, ElementsAre(1));
+ EXPECT_TRUE(info_1100[0].outputs.empty());
+
+ const auto& info_1200 = p->GetEntryPointInfo(1200);
+ EXPECT_EQ(1u, info_1200.size());
+ EXPECT_THAT(info_1200[0].inputs, ElementsAre(2));
+ EXPECT_THAT(info_1200[0].outputs, ElementsAre(15));
+
+ const auto& info_1300 = p->GetEntryPointInfo(1300);
+ EXPECT_EQ(1u, info_1300.size());
+ EXPECT_THAT(info_1300[0].inputs, ElementsAre(1, 2));
+ EXPECT_THAT(info_1300[0].outputs, ElementsAre(15));
+}
+
// TODO(dneto): Test passing pointer to SampleMask as function parameter,
// both input case and output case.
diff --git a/src/reader/spirv/parser_impl_user_name_test.cc b/src/reader/spirv/parser_impl_user_name_test.cc
index d854a9b..99ffbcc 100644
--- a/src/reader/spirv/parser_impl_user_name_test.cc
+++ b/src/reader/spirv/parser_impl_user_name_test.cc
@@ -130,7 +130,7 @@
// has grabbed "main_1" first.
EXPECT_THAT(p->namer().Name(1), Eq("main_1_1"));
- const auto ep_info = p->GetEntryPointInfo(100);
+ const auto& ep_info = p->GetEntryPointInfo(100);
ASSERT_EQ(2u, ep_info.size());
EXPECT_EQ(ep_info[0].name, "main");
EXPECT_EQ(ep_info[1].name, "main_1");