[spirv-reader] Emit StageDecoration when building the functions

This CL adds the emission of StageDecoration to entry point functions.
EntryPoint nodes are still emitted. We duplicate the function emission
if there are multiple entry points pointing to the same function.

Change-Id: Icb48a063f5c6a30948bbe2c37c7fce7431af5864
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28665
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index e75d512..37366c1 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -419,6 +419,7 @@
   sources = [
     "src/reader/spirv/construct.cc",
     "src/reader/spirv/construct.h",
+    "src/reader/spirv/entry_point_info.h",
     "src/reader/spirv/enum_converter.cc",
     "src/reader/spirv/enum_converter.h",
     "src/reader/spirv/fail_stream.h",
@@ -796,7 +797,6 @@
     "src/reader/spirv/namer_test.cc",
     "src/reader/spirv/parser_impl_convert_member_decoration_test.cc",
     "src/reader/spirv/parser_impl_convert_type_test.cc",
-    "src/reader/spirv/parser_impl_entry_point_test.cc",
     "src/reader/spirv/parser_impl_function_decl_test.cc",
     "src/reader/spirv/parser_impl_get_decorations_test.cc",
     "src/reader/spirv/parser_impl_import_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c2edcd0..213cc99 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -222,6 +222,7 @@
   list(APPEND TINT_LIB_SRCS
     reader/spirv/construct.h
     reader/spirv/construct.cc
+    reader/spirv/entry_point_info.h
     reader/spirv/enum_converter.h
     reader/spirv/enum_converter.cc
     reader/spirv/fail_stream.h
@@ -396,7 +397,6 @@
     reader/spirv/namer_test.cc
     reader/spirv/parser_impl_convert_member_decoration_test.cc
     reader/spirv/parser_impl_convert_type_test.cc
-    reader/spirv/parser_impl_entry_point_test.cc
     reader/spirv/parser_impl_function_decl_test.cc
     reader/spirv/parser_impl_get_decorations_test.cc
     reader/spirv/parser_impl_import_test.cc
diff --git a/src/reader/spirv/entry_point_info.h b/src/reader/spirv/entry_point_info.h
new file mode 100644
index 0000000..8256794
--- /dev/null
+++ b/src/reader/spirv/entry_point_info.h
@@ -0,0 +1,38 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_READER_SPIRV_ENTRY_POINT_INFO_H_
+#define SRC_READER_SPIRV_ENTRY_POINT_INFO_H_
+
+#include <string>
+
+#include "src/ast/pipeline_stage.h"
+
+namespace tint {
+namespace reader {
+namespace spirv {
+
+/// Entry point information for a function
+struct EntryPointInfo {
+  /// The entry point name
+  std::string name;
+  /// The entry point stage
+  ast::PipelineStage stage = ast::PipelineStage::kNone;
+};
+
+}  // namespace spirv
+}  // namespace reader
+}  // namespace tint
+
+#endif  // SRC_READER_SPIRV_ENTRY_POINT_INFO_H_
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index f9a3ccb..23052e2 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -48,6 +48,7 @@
 #include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/sint_literal.h"
+#include "src/ast/stage_decoration.h"
 #include "src/ast/storage_class.h"
 #include "src/ast/switch_statement.h"
 #include "src/ast/type/bool_type.h"
@@ -447,7 +448,8 @@
 DefInfo::~DefInfo() = default;
 
 FunctionEmitter::FunctionEmitter(ParserImpl* pi,
-                                 const spvtools::opt::Function& function)
+                                 const spvtools::opt::Function& function,
+                                 const EntryPointInfo* ep_info)
     : parser_impl_(*pi),
       ast_module_(pi->get_module()),
       ir_context_(*(pi->ir_context())),
@@ -456,10 +458,15 @@
       type_mgr_(ir_context_.get_type_mgr()),
       fail_stream_(pi->fail_stream()),
       namer_(pi->namer()),
-      function_(function) {
+      function_(function),
+      ep_info_(ep_info) {
   PushNewStatementBlock(nullptr, 0, nullptr);
 }
 
+FunctionEmitter::FunctionEmitter(ParserImpl* pi,
+                                 const spvtools::opt::Function& function)
+    : FunctionEmitter(pi, function, nullptr) {}
+
 FunctionEmitter::~FunctionEmitter() = default;
 
 FunctionEmitter::StatementBlock::StatementBlock(
@@ -583,7 +590,13 @@
     return false;
   }
 
-  const auto name = namer_.Name(function_.result_id());
+  std::string name;
+  if (ep_info_ == nullptr) {
+    name = namer_.Name(function_.result_id());
+  } else {
+    name = ep_info_->name;
+  }
+
   // Surprisingly, the "type id" on an OpFunction is the result type of the
   // function, not the type of the function.  This is the one exceptional case
   // in SPIR-V where the type ID is not the type of the result ID.
@@ -617,6 +630,12 @@
 
   auto ast_fn =
       std::make_unique<ast::Function>(name, std::move(ast_params), ret_ty);
+
+  if (ep_info_ != nullptr) {
+    ast_fn->add_decoration(
+        std::make_unique<ast::StageDecoration>(ep_info_->stage));
+  }
+
   ast_module_.AddFunction(std::move(ast_fn));
 
   return success();
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index b62994f..c6862c0 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -35,6 +35,7 @@
 #include "src/ast/statement.h"
 #include "src/ast/storage_class.h"
 #include "src/reader/spirv/construct.h"
+#include "src/reader/spirv/entry_point_info.h"
 #include "src/reader/spirv/fail_stream.h"
 #include "src/reader/spirv/namer.h"
 #include "src/reader/spirv/parser_impl.h"
@@ -282,6 +283,14 @@
   /// @param pi a ParserImpl which has already executed BuildInternalModule
   /// @param function the function to emit
   FunctionEmitter(ParserImpl* pi, const spvtools::opt::Function& function);
+  /// Creates a FunctionEmitter, and prepares to write to the AST module
+  /// in |pi|.
+  /// @param pi a ParserImpl which has already executed BuildInternalModule
+  /// @param function the function to emit
+  /// @param ep_info entry point information for this function, or nullptr
+  FunctionEmitter(ParserImpl* pi,
+                  const spvtools::opt::Function& function,
+                  const EntryPointInfo* ep_info);
   /// Destructor
   ~FunctionEmitter();
 
@@ -818,6 +827,9 @@
 
   // Structured constructs, where enclosing constructs precede their children.
   ConstructList constructs_;
+
+  // Information about entry point, if this function is referenced by one
+  const EntryPointInfo* ep_info_ = nullptr;
 };
 
 }  // namespace spirv
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 5b5b2f0..cb05f29 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -526,7 +526,7 @@
   if (!RegisterUserAndStructMemberNames()) {
     return false;
   }
-  if (!EmitEntryPoints()) {
+  if (!RegisterEntryPoints()) {
     return false;
   }
   if (!RegisterTypes()) {
@@ -625,21 +625,20 @@
   return true;
 }
 
-bool ParserImpl::EmitEntryPoints() {
+bool ParserImpl::RegisterEntryPoints() {
   for (const spvtools::opt::Instruction& entry_point :
        module_->entry_points()) {
     const auto stage = SpvExecutionModel(entry_point.GetSingleWordInOperand(0));
     const uint32_t function_id = entry_point.GetSingleWordInOperand(1);
     const std::string ep_name = entry_point.GetOperand(2).AsString();
-    const std::string name = namer_.GetName(function_id);
 
+    EntryPointInfo info{ep_name, enum_converter_.ToPipelineStage(stage)};
     if (!IsValidIdentifier(ep_name)) {
       return Fail() << "entry point name is not a valid WGSL identifier: "
                     << ep_name;
     }
 
-    ast_module_.AddEntryPoint(std::make_unique<ast::EntryPoint>(
-        enum_converter_.ToPipelineStage(stage), ep_name, name));
+    function_to_ep_info_[function_id].push_back(info);
   }
   // The enum conversion could have failed, so return the existing status value.
   return success_;
@@ -1396,8 +1395,21 @@
     if (!success_) {
       return false;
     }
-    FunctionEmitter emitter(this, *f);
-    success_ = emitter.Emit();
+
+    auto id = f->result_id();
+    auto it = function_to_ep_info_.find(id);
+    if (it == function_to_ep_info_.end()) {
+      FunctionEmitter emitter(this, *f, nullptr);
+      success_ = emitter.Emit();
+    } else {
+      for (const auto& ep : it->second) {
+        FunctionEmitter emitter(this, *f, &ep);
+        success_ = emitter.Emit();
+        if (!success_) {
+          return false;
+        }
+      }
+    }
   }
   return success_;
 }
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 649bdc0..8f6abfa 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -38,6 +38,7 @@
 #include "src/ast/type/alias_type.h"
 #include "src/ast/type/type.h"
 #include "src/reader/reader.h"
+#include "src/reader/spirv/entry_point_info.h"
 #include "src/reader/spirv/enum_converter.h"
 #include "src/reader/spirv/fail_stream.h"
 #include "src/reader/spirv/namer.h"
@@ -239,10 +240,10 @@
   /// @returns true if parser is still successful.
   bool RegisterUserAndStructMemberNames();
 
-  /// Emit entry point AST nodes.
+  /// Register entry point information.
   /// This is a no-op if the parser has already failed.
   /// @returns true if parser is still successful.
-  bool EmitEntryPoints();
+  bool RegisterEntryPoints();
 
   /// Register Tint AST types for SPIR-V types, including type aliases as
   /// needed.  This is a no-op if the parser has already failed.
@@ -489,6 +490,10 @@
   // on the struct.  The new style is to use the StorageBuffer storage class
   // and Block decoration.
   std::unordered_set<uint32_t> remap_buffer_block_type_;
+
+  // Maps function_id to a list of entrypoint information
+  std::unordered_map<uint32_t, std::vector<EntryPointInfo>>
+      function_to_ep_info_;
 };
 
 }  // namespace spirv
diff --git a/src/reader/spirv/parser_impl_entry_point_test.cc b/src/reader/spirv/parser_impl_entry_point_test.cc
deleted file mode 100644
index 507e5f1..0000000
--- a/src/reader/spirv/parser_impl_entry_point_test.cc
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <string>
-
-#include "gmock/gmock.h"
-#include "src/reader/spirv/parser_impl.h"
-#include "src/reader/spirv/parser_impl_test_helper.h"
-#include "src/reader/spirv/spirv_tools_helpers_test.h"
-
-namespace tint {
-namespace reader {
-namespace spirv {
-namespace {
-
-using ::testing::Eq;
-using ::testing::HasSubstr;
-
-std::string MakeEntryPoint(const std::string& stage,
-                           const std::string& name,
-                           const std::string& id = "42") {
-  return std::string("OpEntryPoint ") + stage + " %" + id + " \"" + name +
-         "\"\n" +  // Give the target ID a definition.
-         "%" + id + " = OpTypeVoid\n";
-}
-
-TEST_F(SpvParserTest, EntryPoint_NoEntryPoint) {
-  auto* p = parser(test::Assemble(""));
-  EXPECT_TRUE(p->BuildAndParseInternalModule());
-  EXPECT_TRUE(p->error().empty());
-  const auto module_ast = p->module().to_str();
-  EXPECT_THAT(module_ast, Not(HasSubstr("EntryPoint")));
-}
-
-TEST_F(SpvParserTest, EntryPoint_Vertex) {
-  auto* p = parser(test::Assemble(MakeEntryPoint("Vertex", "foobar")));
-  EXPECT_TRUE(p->BuildAndParseInternalModule());
-  EXPECT_TRUE(p->error().empty());
-  const auto module_str = p->module().to_str();
-  EXPECT_THAT(module_str,
-              HasSubstr(R"(EntryPoint{vertex as foobar = foobar})"));
-}
-
-TEST_F(SpvParserTest, EntryPoint_Fragment) {
-  auto* p = parser(test::Assemble(MakeEntryPoint("Fragment", "blitz")));
-  EXPECT_TRUE(p->BuildAndParseInternalModule());
-  EXPECT_TRUE(p->error().empty());
-  const auto module_str = p->module().to_str();
-  EXPECT_THAT(module_str,
-              HasSubstr(R"(EntryPoint{fragment as blitz = blitz})"));
-}
-
-TEST_F(SpvParserTest, EntryPoint_Compute) {
-  auto* p = parser(test::Assemble(MakeEntryPoint("GLCompute", "sort")));
-  EXPECT_TRUE(p->BuildAndParseInternalModule());
-  EXPECT_TRUE(p->error().empty());
-  const auto module_str = p->module().to_str();
-  EXPECT_THAT(module_str, HasSubstr(R"(EntryPoint{compute as sort = sort})"));
-}
-
-TEST_F(SpvParserTest, EntryPoint_MultiNameConflict) {
-  auto* p = parser(test::Assemble(MakeEntryPoint("GLCompute", "work", "40") +
-                                  MakeEntryPoint("Vertex", "work", "50") +
-                                  MakeEntryPoint("Fragment", "work", "60")));
-  EXPECT_TRUE(p->BuildAndParseInternalModule());
-  EXPECT_TRUE(p->error().empty());
-  const auto module_str = p->module().to_str();
-  EXPECT_THAT(module_str, HasSubstr(R"(EntryPoint{compute as work = work})"));
-  EXPECT_THAT(module_str, HasSubstr(R"(EntryPoint{vertex as work = work_1})"));
-  EXPECT_THAT(module_str,
-              HasSubstr(R"(EntryPoint{fragment as work = work_2})"));
-}
-
-TEST_F(SpvParserTest, EntryPoint_MustBeWgslIdentifier) {
-  auto* p = parser(test::Assemble(MakeEntryPoint("GLCompute", ".1234")));
-  EXPECT_FALSE(p->BuildAndParseInternalModule());
-  EXPECT_THAT(p->error(),
-              Eq("entry point name is not a valid WGSL identifier: .1234"));
-}
-
-}  // namespace
-}  // namespace spirv
-}  // namespace reader
-}  // namespace tint
diff --git a/src/reader/spirv/parser_impl_function_decl_test.cc b/src/reader/spirv/parser_impl_function_decl_test.cc
index 7fc86d4..e840830 100644
--- a/src/reader/spirv/parser_impl_function_decl_test.cc
+++ b/src/reader/spirv/parser_impl_function_decl_test.cc
@@ -68,6 +68,89 @@
   EXPECT_THAT(module_ast, Not(HasSubstr("Function{")));
 }
 
+TEST_F(SpvParserTest, EmitFunctions_Function_EntryPoint_Vertex) {
+  std::string input = Names({"main"}) + R"(OpEntryPoint Vertex %main "main"
+)" + CommonTypes() + R"(
+%main = OpFunction %void None %voidfn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd)";
+
+  auto* p = parser(test::Assemble(input));
+  ASSERT_TRUE(p->BuildAndParseInternalModule());
+  ASSERT_TRUE(p->error().empty()) << p->error();
+  const auto module_ast = p->module().to_str();
+  EXPECT_THAT(module_ast, HasSubstr(R"(
+  Function main -> __void
+  StageDecoration{vertex}
+  ()
+  {)"));
+}
+
+TEST_F(SpvParserTest, EmitFunctions_Function_EntryPoint_Fragment) {
+  std::string input = Names({"main"}) + R"(OpEntryPoint Fragment %main "main"
+)" + CommonTypes() + R"(
+%main = OpFunction %void None %voidfn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd)";
+
+  auto* p = parser(test::Assemble(input));
+  ASSERT_TRUE(p->BuildAndParseInternalModule());
+  ASSERT_TRUE(p->error().empty()) << p->error();
+  const auto module_ast = p->module().to_str();
+  EXPECT_THAT(module_ast, HasSubstr(R"(
+  Function main -> __void
+  StageDecoration{fragment}
+  ()
+  {)"));
+}
+
+TEST_F(SpvParserTest, EmitFunctions_Function_EntryPoint_GLCompute) {
+  std::string input = Names({"main"}) + R"(OpEntryPoint GLCompute %main "main"
+)" + CommonTypes() + R"(
+%main = OpFunction %void None %voidfn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd)";
+
+  auto* p = parser(test::Assemble(input));
+  ASSERT_TRUE(p->BuildAndParseInternalModule());
+  ASSERT_TRUE(p->error().empty()) << p->error();
+  const auto module_ast = p->module().to_str();
+  EXPECT_THAT(module_ast, HasSubstr(R"(
+  Function main -> __void
+  StageDecoration{compute}
+  ()
+  {)"));
+}
+
+TEST_F(SpvParserTest, EmitFunctions_Function_EntryPoint_MultipleEntryPoints) {
+  std::string input = Names({"main"}) +
+                      R"(OpEntryPoint GLCompute %main "comp_main"
+OpEntryPoint Fragment %main "frag_main"
+)" + CommonTypes() + R"(
+%main = OpFunction %void None %voidfn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd)";
+
+  auto* p = parser(test::Assemble(input));
+  ASSERT_TRUE(p->BuildAndParseInternalModule());
+  ASSERT_TRUE(p->error().empty()) << p->error();
+  const auto module_ast = p->module().to_str();
+  EXPECT_THAT(module_ast, HasSubstr(R"(
+  Function frag_main -> __void
+  StageDecoration{fragment}
+  ()
+  {)"));
+  EXPECT_THAT(module_ast, HasSubstr(R"(
+  Function comp_main -> __void
+  StageDecoration{compute}
+  ()
+  {)"));
+}
+
 TEST_F(SpvParserTest, EmitFunctions_VoidFunctionWithoutParams) {
   auto* p = parser(test::Assemble(Names({"main"}) + CommonTypes() + R"(
      %main = OpFunction %void None %voidfn