[type-determiner] Update to work with entry point and function stages.
This Cl updates the type determiner to work with both styles of entry
point definition.
Change-Id: Ic48f1a5f0a5820821f9a74380896426a97483049
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28666
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 8e8d853..23bd004 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -17,6 +17,7 @@
#include <sstream>
#include "src/ast/decorated_variable.h"
+#include "src/ast/stage_decoration.h"
#include "src/ast/workgroup_decoration.h"
namespace tint {
@@ -56,6 +57,15 @@
return {1, 1, 1};
}
+ast::PipelineStage Function::pipeline_stage() const {
+ for (const auto& deco : decorations_) {
+ if (deco->IsStage()) {
+ return deco->AsStage()->value();
+ }
+ }
+ return ast::PipelineStage::kNone;
+}
+
void Function::add_referenced_module_variable(Variable* var) {
for (const auto* v : referenced_module_vars_) {
if (v->name() == var->name()) {
diff --git a/src/ast/function.h b/src/ast/function.h
index 078f692..5cb75ba 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -28,6 +28,7 @@
#include "src/ast/function_decoration.h"
#include "src/ast/location_decoration.h"
#include "src/ast/node.h"
+#include "src/ast/pipeline_stage.h"
#include "src/ast/set_decoration.h"
#include "src/ast/statement.h"
#include "src/ast/type/type.h"
@@ -100,6 +101,9 @@
/// return if no workgroup size was set.
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
+ /// @returns the functions pipeline stage or None if not set
+ ast::PipelineStage pipeline_stage() const;
+
/// Adds the given variable to the list of referenced module variables if it
/// is not already included.
/// @param var the module variable to add
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 948b04c..a55ec6c 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -217,6 +217,17 @@
}
}
+ // Walk over the caller to callee information and update functions with which
+ // entry points call those functions.
+ for (const auto& func : mod_->functions()) {
+ if (func->pipeline_stage() == ast::PipelineStage::kNone) {
+ continue;
+ }
+ for (const auto& callee : caller_to_callee_[func->name()]) {
+ set_entry_points(callee, func->name());
+ }
+ }
+
return true;
}
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 512e0b2..67821ad 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -37,9 +37,11 @@
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
+#include "src/ast/pipeline_stage.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h"
+#include "src/ast/stage_decoration.h"
#include "src/ast/struct.h"
#include "src/ast/struct_member.h"
#include "src/ast/switch_statement.h"
@@ -4479,5 +4481,107 @@
EXPECT_TRUE(ep_2_func_ptr->ancestor_entry_points().empty());
}
+TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
+ ast::type::F32Type f32;
+
+ // fn b() {}
+ // fn c() { b(); }
+ // fn a() { c(); }
+ // fn ep_1() { a(); b(); }
+ // fn ep_2() { c();}
+ //
+ // c -> {ep_1, ep_2}
+ // a -> {ep_1}
+ // b -> {ep_1, ep_2}
+ // ep_1 -> {}
+ // ep_2 -> {}
+
+ ast::VariableList params;
+ auto func_b = std::make_unique<ast::Function>("b", std::move(params), &f32);
+ auto* func_b_ptr = func_b.get();
+
+ auto body = std::make_unique<ast::BlockStatement>();
+ func_b->set_body(std::move(body));
+
+ auto func_c = std::make_unique<ast::Function>("c", std::move(params), &f32);
+ auto* func_c_ptr = func_c.get();
+
+ body = std::make_unique<ast::BlockStatement>();
+ body->append(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("second"),
+ std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("b"),
+ ast::ExpressionList{})));
+ func_c->set_body(std::move(body));
+
+ auto func_a = std::make_unique<ast::Function>("a", std::move(params), &f32);
+ auto* func_a_ptr = func_a.get();
+
+ body = std::make_unique<ast::BlockStatement>();
+ body->append(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("first"),
+ std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("c"),
+ ast::ExpressionList{})));
+ func_a->set_body(std::move(body));
+
+ auto ep_1 = std::make_unique<ast::Function>("ep_1", std::move(params), &f32);
+ ep_1->add_decoration(
+ std::make_unique<ast::StageDecoration>(ast::PipelineStage::kVertex));
+ auto* ep_1_ptr = ep_1.get();
+
+ body = std::make_unique<ast::BlockStatement>();
+ body->append(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("call_a"),
+ std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("a"),
+ ast::ExpressionList{})));
+ body->append(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("call_b"),
+ std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("b"),
+ ast::ExpressionList{})));
+ ep_1->set_body(std::move(body));
+
+ auto ep_2 = std::make_unique<ast::Function>("ep_2", std::move(params), &f32);
+ ep_2->add_decoration(
+ std::make_unique<ast::StageDecoration>(ast::PipelineStage::kVertex));
+ auto* ep_2_ptr = ep_2.get();
+
+ body = std::make_unique<ast::BlockStatement>();
+ body->append(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("call_c"),
+ std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("c"),
+ ast::ExpressionList{})));
+ ep_2->set_body(std::move(body));
+
+ mod()->AddFunction(std::move(func_b));
+ mod()->AddFunction(std::move(func_c));
+ mod()->AddFunction(std::move(func_a));
+ mod()->AddFunction(std::move(ep_1));
+ mod()->AddFunction(std::move(ep_2));
+
+ // Register the functions and calculate the callers
+ ASSERT_TRUE(td()->Determine()) << td()->error();
+
+ const auto& b_eps = func_b_ptr->ancestor_entry_points();
+ ASSERT_EQ(2u, b_eps.size());
+ EXPECT_EQ("ep_1", b_eps[0]);
+ EXPECT_EQ("ep_2", b_eps[1]);
+
+ const auto& a_eps = func_a_ptr->ancestor_entry_points();
+ ASSERT_EQ(1u, a_eps.size());
+ EXPECT_EQ("ep_1", a_eps[0]);
+
+ const auto& c_eps = func_c_ptr->ancestor_entry_points();
+ ASSERT_EQ(2u, c_eps.size());
+ EXPECT_EQ("ep_1", c_eps[0]);
+ EXPECT_EQ("ep_2", c_eps[1]);
+
+ EXPECT_TRUE(ep_1_ptr->ancestor_entry_points().empty());
+ EXPECT_TRUE(ep_2_ptr->ancestor_entry_points().empty());
+}
+
} // namespace
} // namespace tint