Inherit refererenced globals up the call stack.
A given caller should inherit the globals referenced from a callee. This
way, a given entry point will have a list of all the variables used up
the stack which it needs to reference.
Change-Id: Ib6efcdd5c3347749ad2d54aecfa425bd966a62fd
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24762
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 82d45a4..6b1d139 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -477,6 +477,18 @@
} else {
if (current_function_) {
caller_to_callee_[current_function_->name()].push_back(ident->name());
+
+ auto* callee_func = mod_->FindFunctionByName(ident->name());
+ if (callee_func == nullptr) {
+ set_error(expr->source(),
+ "unable to find called function: " + ident->name());
+ return false;
+ }
+
+ // We inherit any referenced variables from the callee.
+ for (auto* var : callee_func->referenced_module_variables()) {
+ set_referenced_from_function_if_needed(var);
+ }
}
// An identifier with a single name is a function call, not an import
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 183bbc2..983ac36 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -803,6 +803,76 @@
EXPECT_EQ(vars[4], priv_ptr);
}
+TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) {
+ ast::type::F32Type f32;
+
+ auto in_var = std::make_unique<ast::Variable>(
+ "in_var", ast::StorageClass::kInput, &f32);
+ auto out_var = std::make_unique<ast::Variable>(
+ "out_var", ast::StorageClass::kOutput, &f32);
+ auto sb_var = std::make_unique<ast::Variable>(
+ "sb_var", ast::StorageClass::kStorageBuffer, &f32);
+ auto wg_var = std::make_unique<ast::Variable>(
+ "wg_var", ast::StorageClass::kWorkgroup, &f32);
+ auto priv_var = std::make_unique<ast::Variable>(
+ "priv_var", ast::StorageClass::kPrivate, &f32);
+
+ auto* in_ptr = in_var.get();
+ auto* out_ptr = out_var.get();
+ auto* sb_ptr = sb_var.get();
+ auto* wg_ptr = wg_var.get();
+ auto* priv_ptr = priv_var.get();
+
+ mod()->AddGlobalVariable(std::move(in_var));
+ mod()->AddGlobalVariable(std::move(out_var));
+ mod()->AddGlobalVariable(std::move(sb_var));
+ mod()->AddGlobalVariable(std::move(wg_var));
+ mod()->AddGlobalVariable(std::move(priv_var));
+
+ ast::VariableList params;
+ auto func =
+ std::make_unique<ast::Function>("my_func", std::move(params), &f32);
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("out_var"),
+ std::make_unique<ast::IdentifierExpression>("in_var")));
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("wg_var"),
+ std::make_unique<ast::IdentifierExpression>("wg_var")));
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("sb_var"),
+ std::make_unique<ast::IdentifierExpression>("sb_var")));
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("priv_var"),
+ std::make_unique<ast::IdentifierExpression>("priv_var")));
+ func->set_body(std::move(body));
+
+ mod()->AddFunction(std::move(func));
+
+ auto func2 = std::make_unique<ast::Function>("func", std::move(params), &f32);
+ auto* func2_ptr = func2.get();
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("out_var"),
+ std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("my_func"),
+ ast::ExpressionList{})));
+ func2->set_body(std::move(body));
+
+ mod()->AddFunction(std::move(func2));
+
+ // Register the function
+ EXPECT_TRUE(td()->Determine());
+
+ const auto& vars = func2_ptr->referenced_module_variables();
+ ASSERT_EQ(vars.size(), 5u);
+ EXPECT_EQ(vars[0], out_ptr);
+ EXPECT_EQ(vars[1], in_ptr);
+ EXPECT_EQ(vars[2], wg_ptr);
+ EXPECT_EQ(vars[3], sb_ptr);
+ EXPECT_EQ(vars[4], priv_ptr);
+}
+
TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) {
ast::type::F32Type f32;