spirv-reader: register statically accessed inputs and outputs

Bug: tint:508
Change-Id: I585abb0791f5ea0bcb282f12f6940e718da4956d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48861
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: David Neto <dneto@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 941e1e7..0cdaac4 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -581,6 +581,7 @@
   sources = [
     "reader/spirv/construct.cc",
     "reader/spirv/construct.h",
+    "reader/spirv/entry_point_info.cc",
     "reader/spirv/entry_point_info.h",
     "reader/spirv/enum_converter.cc",
     "reader/spirv/enum_converter.h",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 795091f..658f7dd 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -352,6 +352,7 @@
     reader/spirv/construct.h
     reader/spirv/construct.cc
     reader/spirv/entry_point_info.h
+    reader/spirv/entry_point_info.cc
     reader/spirv/enum_converter.h
     reader/spirv/enum_converter.cc
     reader/spirv/fail_stream.h
diff --git a/src/reader/spirv/entry_point_info.cc b/src/reader/spirv/entry_point_info.cc
new file mode 100644
index 0000000..61f1586
--- /dev/null
+++ b/src/reader/spirv/entry_point_info.cc
@@ -0,0 +1,38 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/reader/spirv/entry_point_info.h"
+
+#include <utility>
+
+namespace tint {
+namespace reader {
+namespace spirv {
+
+EntryPointInfo::EntryPointInfo(std::string the_name,
+                               ast::PipelineStage the_stage,
+                               std::vector<uint32_t>&& the_inputs,
+                               std::vector<uint32_t>&& the_outputs)
+    : name(the_name),
+      stage(the_stage),
+      inputs(std::move(the_inputs)),
+      outputs(std::move(the_outputs)) {}
+
+EntryPointInfo::EntryPointInfo(const EntryPointInfo&) = default;
+
+EntryPointInfo::~EntryPointInfo() = default;
+
+}  // namespace spirv
+}  // namespace reader
+}  // namespace tint
diff --git a/src/reader/spirv/entry_point_info.h b/src/reader/spirv/entry_point_info.h
index 8256794..8cb11f3 100644
--- a/src/reader/spirv/entry_point_info.h
+++ b/src/reader/spirv/entry_point_info.h
@@ -16,6 +16,7 @@
 #define SRC_READER_SPIRV_ENTRY_POINT_INFO_H_
 
 #include <string>
+#include <vector>
 
 #include "src/ast/pipeline_stage.h"
 
@@ -25,10 +26,29 @@
 
 /// Entry point information for a function
 struct EntryPointInfo {
+  // Constructor.
+  // @param the_name the name of the entry point
+  // @param the_stage the pipeline stage
+  // @param the_inputs list of IDs for Input variables used by the shader
+  // @param the_outputs list of IDs for Output variables used by the shader
+  EntryPointInfo(std::string the_name,
+                 ast::PipelineStage the_stage,
+                 std::vector<uint32_t>&& the_inputs,
+                 std::vector<uint32_t>&& the_outputs);
+  // Copy constructor
+  // @param other the other entry point info to be built from
+  EntryPointInfo(const EntryPointInfo& other);
+  // Destructor
+  ~EntryPointInfo();
+
   /// The entry point name
   std::string name;
   /// The entry point stage
   ast::PipelineStage stage = ast::PipelineStage::kNone;
+  /// IDs of pipeline input variables, sorted and without duplicates.
+  std::vector<uint32_t> inputs;
+  /// IDs of pipeline output variables, sorted and without duplicates.
+  std::vector<uint32_t> outputs;
 };
 
 }  // namespace spirv
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index f147b70..32b32b0 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -14,8 +14,10 @@
 
 #include "src/reader/spirv/parser_impl.h"
 
+#include <algorithm>
 #include <limits>
 #include <locale>
+#include <utility>
 
 #include "source/opt/build_module.h"
 #include "src/ast/bitcast_expression.h"
@@ -26,6 +28,7 @@
 #include "src/sem/depth_texture_type.h"
 #include "src/sem/multisampled_texture_type.h"
 #include "src/sem/sampled_texture_type.h"
+#include "src/utils/unique_vector.h"
 
 namespace tint {
 namespace reader {
@@ -711,8 +714,32 @@
     const uint32_t function_id = entry_point.GetSingleWordInOperand(1);
     const std::string ep_name = entry_point.GetOperand(2).AsString();
 
-    EntryPointInfo info{ep_name, enum_converter_.ToPipelineStage(stage)};
-    function_to_ep_info_[function_id].push_back(info);
+    tint::UniqueVector<uint32_t> inputs;
+    tint::UniqueVector<uint32_t> outputs;
+    for (unsigned iarg = 3; iarg < entry_point.NumInOperands(); iarg++) {
+      const uint32_t var_id = entry_point.GetSingleWordInOperand(iarg);
+      if (const auto* var_inst = def_use_mgr_->GetDef(var_id)) {
+        switch (SpvStorageClass(var_inst->GetSingleWordInOperand(0))) {
+          case SpvStorageClassInput:
+            inputs.add(var_id);
+            break;
+          case SpvStorageClassOutput:
+            outputs.add(var_id);
+            break;
+          default:
+            break;
+        }
+      }
+    }
+    // Save the lists, in ID-sorted order.
+    std::vector<uint32_t> sorted_inputs(inputs.begin(), inputs.end());
+    std::sort(sorted_inputs.begin(), sorted_inputs.end());
+    std::vector<uint32_t> sorted_outputs(outputs.begin(), outputs.end());
+    std::sort(sorted_inputs.begin(), sorted_inputs.end());
+
+    function_to_ep_info_[function_id].emplace_back(
+        ep_name, enum_converter_.ToPipelineStage(stage),
+        std::move(sorted_inputs), std::move(sorted_outputs));
   }
   // The enum conversion could have failed, so return the existing status value.
   return success_;
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index b52daa2..714d8dc 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -24,6 +24,7 @@
 
 using SpvModuleScopeVarParserTest = SpvParserTest;
 
+using ::testing::ElementsAre;
 using ::testing::Eq;
 using ::testing::HasSubstr;
 using ::testing::Not;
@@ -3722,6 +3723,104 @@
 })")) << module_str;
 }
 
+TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
+  const std::string assembly =
+      R"(
+    OpCapability Shader
+    OpMemoryModel Logical Simple
+    OpEntryPoint GLCompute %1000 "w1000"
+    OpEntryPoint GLCompute %1100 "w1100" %1
+    OpEntryPoint GLCompute %1200 "w1200" %2 %15
+    ; duplication is tolerated prior to SPIR-V 1.4
+    OpEntryPoint GLCompute %1300 "w1300" %1 %15 %2 %1
+
+)" + CommonTypes() +
+      R"(
+
+    %ptr_in_uint = OpTypePointer Input %uint
+    %ptr_out_uint = OpTypePointer Output %uint
+
+    %1 = OpVariable %ptr_in_uint Input
+    %2 = OpVariable %ptr_in_uint Input
+    %5 = OpVariable %ptr_in_uint Input
+    %11 = OpVariable %ptr_out_uint Output
+    %12 = OpVariable %ptr_out_uint Output
+    %15 = OpVariable %ptr_out_uint Output
+
+    %100 = OpFunction %void None %voidfn
+    %entry_100 = OpLabel
+    %load_100 = OpLoad %uint %1
+    OpReturn
+    OpFunctionEnd
+
+    %200 = OpFunction %void None %voidfn
+    %entry_200 = OpLabel
+    %load_200 = OpLoad %uint %2
+    OpStore %15 %load_200
+    OpStore %15 %load_200
+    OpReturn
+    OpFunctionEnd
+
+    %300 = OpFunction %void None %voidfn
+    %entry_300 = OpLabel
+    %dummy_300_1 = OpFunctionCall %void %100
+    %dummy_300_2 = OpFunctionCall %void %200
+    OpReturn
+    OpFunctionEnd
+
+    ; Call nothing
+    %1000 = OpFunction %void None %voidfn
+    %entry_1000 = OpLabel
+    OpReturn
+    OpFunctionEnd
+
+    ; Call %100
+    %1100 = OpFunction %void None %voidfn
+    %entry_1100 = OpLabel
+    %dummy_1100_1 = OpFunctionCall %void %100
+    OpReturn
+    OpFunctionEnd
+
+    ; Call %200
+    %1200 = OpFunction %void None %voidfn
+    %entry_1200 = OpLabel
+    %dummy_1200_1 = OpFunctionCall %void %200
+    OpReturn
+    OpFunctionEnd
+
+    ; Call %300
+    %1300 = OpFunction %void None %voidfn
+    %entry_1300 = OpLabel
+    %dummy_1300_1 = OpFunctionCall %void %300
+    OpReturn
+    OpFunctionEnd
+
+ )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+
+  const auto& info_1000 = p->GetEntryPointInfo(1000);
+  EXPECT_EQ(1u, info_1000.size());
+  EXPECT_TRUE(info_1000[0].inputs.empty());
+  EXPECT_TRUE(info_1000[0].outputs.empty());
+
+  const auto& info_1100 = p->GetEntryPointInfo(1100);
+  EXPECT_EQ(1u, info_1100.size());
+  EXPECT_THAT(info_1100[0].inputs, ElementsAre(1));
+  EXPECT_TRUE(info_1100[0].outputs.empty());
+
+  const auto& info_1200 = p->GetEntryPointInfo(1200);
+  EXPECT_EQ(1u, info_1200.size());
+  EXPECT_THAT(info_1200[0].inputs, ElementsAre(2));
+  EXPECT_THAT(info_1200[0].outputs, ElementsAre(15));
+
+  const auto& info_1300 = p->GetEntryPointInfo(1300);
+  EXPECT_EQ(1u, info_1300.size());
+  EXPECT_THAT(info_1300[0].inputs, ElementsAre(1, 2));
+  EXPECT_THAT(info_1300[0].outputs, ElementsAre(15));
+}
+
 // TODO(dneto): Test passing pointer to SampleMask as function parameter,
 // both input case and output case.
 
diff --git a/src/reader/spirv/parser_impl_user_name_test.cc b/src/reader/spirv/parser_impl_user_name_test.cc
index d854a9b..99ffbcc 100644
--- a/src/reader/spirv/parser_impl_user_name_test.cc
+++ b/src/reader/spirv/parser_impl_user_name_test.cc
@@ -130,7 +130,7 @@
   // has grabbed "main_1" first.
   EXPECT_THAT(p->namer().Name(1), Eq("main_1_1"));
 
-  const auto ep_info = p->GetEntryPointInfo(100);
+  const auto& ep_info = p->GetEntryPointInfo(100);
   ASSERT_EQ(2u, ep_info.size());
   EXPECT_EQ(ep_info[0].name, "main");
   EXPECT_EQ(ep_info[1].name, "main_1");