[type-determiner] Update to work with entry point and function stages.

This Cl updates the type determiner to work with both styles of entry
point definition.

Change-Id: Ic48f1a5f0a5820821f9a74380896426a97483049
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28666
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 8e8d853..23bd004 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -17,6 +17,7 @@
 #include <sstream>
 
 #include "src/ast/decorated_variable.h"
+#include "src/ast/stage_decoration.h"
 #include "src/ast/workgroup_decoration.h"
 
 namespace tint {
@@ -56,6 +57,15 @@
   return {1, 1, 1};
 }
 
+ast::PipelineStage Function::pipeline_stage() const {
+  for (const auto& deco : decorations_) {
+    if (deco->IsStage()) {
+      return deco->AsStage()->value();
+    }
+  }
+  return ast::PipelineStage::kNone;
+}
+
 void Function::add_referenced_module_variable(Variable* var) {
   for (const auto* v : referenced_module_vars_) {
     if (v->name() == var->name()) {
diff --git a/src/ast/function.h b/src/ast/function.h
index 078f692..5cb75ba 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -28,6 +28,7 @@
 #include "src/ast/function_decoration.h"
 #include "src/ast/location_decoration.h"
 #include "src/ast/node.h"
+#include "src/ast/pipeline_stage.h"
 #include "src/ast/set_decoration.h"
 #include "src/ast/statement.h"
 #include "src/ast/type/type.h"
@@ -100,6 +101,9 @@
   /// return if no workgroup size was set.
   std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
 
+  /// @returns the functions pipeline stage or None if not set
+  ast::PipelineStage pipeline_stage() const;
+
   /// Adds the given variable to the list of referenced module variables if it
   /// is not already included.
   /// @param var the module variable to add
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 948b04c..a55ec6c 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -217,6 +217,17 @@
     }
   }
 
+  // Walk over the caller to callee information and update functions with which
+  // entry points call those functions.
+  for (const auto& func : mod_->functions()) {
+    if (func->pipeline_stage() == ast::PipelineStage::kNone) {
+      continue;
+    }
+    for (const auto& callee : caller_to_callee_[func->name()]) {
+      set_entry_points(callee, func->name());
+    }
+  }
+
   return true;
 }
 
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 512e0b2..67821ad 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -37,9 +37,11 @@
 #include "src/ast/if_statement.h"
 #include "src/ast/loop_statement.h"
 #include "src/ast/member_accessor_expression.h"
+#include "src/ast/pipeline_stage.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/sint_literal.h"
+#include "src/ast/stage_decoration.h"
 #include "src/ast/struct.h"
 #include "src/ast/struct_member.h"
 #include "src/ast/switch_statement.h"
@@ -4479,5 +4481,107 @@
   EXPECT_TRUE(ep_2_func_ptr->ancestor_entry_points().empty());
 }
 
+TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
+  ast::type::F32Type f32;
+
+  // fn b() {}
+  // fn c() { b(); }
+  // fn a() { c(); }
+  // fn ep_1() { a(); b(); }
+  // fn ep_2() { c();}
+  //
+  // c -> {ep_1, ep_2}
+  // a -> {ep_1}
+  // b -> {ep_1, ep_2}
+  // ep_1 -> {}
+  // ep_2 -> {}
+
+  ast::VariableList params;
+  auto func_b = std::make_unique<ast::Function>("b", std::move(params), &f32);
+  auto* func_b_ptr = func_b.get();
+
+  auto body = std::make_unique<ast::BlockStatement>();
+  func_b->set_body(std::move(body));
+
+  auto func_c = std::make_unique<ast::Function>("c", std::move(params), &f32);
+  auto* func_c_ptr = func_c.get();
+
+  body = std::make_unique<ast::BlockStatement>();
+  body->append(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("second"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("b"),
+          ast::ExpressionList{})));
+  func_c->set_body(std::move(body));
+
+  auto func_a = std::make_unique<ast::Function>("a", std::move(params), &f32);
+  auto* func_a_ptr = func_a.get();
+
+  body = std::make_unique<ast::BlockStatement>();
+  body->append(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("first"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("c"),
+          ast::ExpressionList{})));
+  func_a->set_body(std::move(body));
+
+  auto ep_1 = std::make_unique<ast::Function>("ep_1", std::move(params), &f32);
+  ep_1->add_decoration(
+      std::make_unique<ast::StageDecoration>(ast::PipelineStage::kVertex));
+  auto* ep_1_ptr = ep_1.get();
+
+  body = std::make_unique<ast::BlockStatement>();
+  body->append(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("call_a"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("a"),
+          ast::ExpressionList{})));
+  body->append(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("call_b"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("b"),
+          ast::ExpressionList{})));
+  ep_1->set_body(std::move(body));
+
+  auto ep_2 = std::make_unique<ast::Function>("ep_2", std::move(params), &f32);
+  ep_2->add_decoration(
+      std::make_unique<ast::StageDecoration>(ast::PipelineStage::kVertex));
+  auto* ep_2_ptr = ep_2.get();
+
+  body = std::make_unique<ast::BlockStatement>();
+  body->append(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("call_c"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("c"),
+          ast::ExpressionList{})));
+  ep_2->set_body(std::move(body));
+
+  mod()->AddFunction(std::move(func_b));
+  mod()->AddFunction(std::move(func_c));
+  mod()->AddFunction(std::move(func_a));
+  mod()->AddFunction(std::move(ep_1));
+  mod()->AddFunction(std::move(ep_2));
+
+  // Register the functions and calculate the callers
+  ASSERT_TRUE(td()->Determine()) << td()->error();
+
+  const auto& b_eps = func_b_ptr->ancestor_entry_points();
+  ASSERT_EQ(2u, b_eps.size());
+  EXPECT_EQ("ep_1", b_eps[0]);
+  EXPECT_EQ("ep_2", b_eps[1]);
+
+  const auto& a_eps = func_a_ptr->ancestor_entry_points();
+  ASSERT_EQ(1u, a_eps.size());
+  EXPECT_EQ("ep_1", a_eps[0]);
+
+  const auto& c_eps = func_c_ptr->ancestor_entry_points();
+  ASSERT_EQ(2u, c_eps.size());
+  EXPECT_EQ("ep_1", c_eps[0]);
+  EXPECT_EQ("ep_2", c_eps[1]);
+
+  EXPECT_TRUE(ep_1_ptr->ancestor_entry_points().empty());
+  EXPECT_TRUE(ep_2_ptr->ancestor_entry_points().empty());
+}
+
 }  // namespace
 }  // namespace tint