Inherit refererenced globals up the call stack.

A given caller should inherit the globals referenced from a callee. This
way, a given entry point will have a list of all the variables used up
the stack which it needs to reference.

Change-Id: Ib6efcdd5c3347749ad2d54aecfa425bd966a62fd
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24762
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 82d45a4..6b1d139 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -477,6 +477,18 @@
     } else {
       if (current_function_) {
         caller_to_callee_[current_function_->name()].push_back(ident->name());
+
+        auto* callee_func = mod_->FindFunctionByName(ident->name());
+        if (callee_func == nullptr) {
+          set_error(expr->source(),
+                    "unable to find called function: " + ident->name());
+          return false;
+        }
+
+        // We inherit any referenced variables from the callee.
+        for (auto* var : callee_func->referenced_module_variables()) {
+          set_referenced_from_function_if_needed(var);
+        }
       }
 
       // An identifier with a single name is a function call, not an import
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 183bbc2..983ac36 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -803,6 +803,76 @@
   EXPECT_EQ(vars[4], priv_ptr);
 }
 
+TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) {
+  ast::type::F32Type f32;
+
+  auto in_var = std::make_unique<ast::Variable>(
+      "in_var", ast::StorageClass::kInput, &f32);
+  auto out_var = std::make_unique<ast::Variable>(
+      "out_var", ast::StorageClass::kOutput, &f32);
+  auto sb_var = std::make_unique<ast::Variable>(
+      "sb_var", ast::StorageClass::kStorageBuffer, &f32);
+  auto wg_var = std::make_unique<ast::Variable>(
+      "wg_var", ast::StorageClass::kWorkgroup, &f32);
+  auto priv_var = std::make_unique<ast::Variable>(
+      "priv_var", ast::StorageClass::kPrivate, &f32);
+
+  auto* in_ptr = in_var.get();
+  auto* out_ptr = out_var.get();
+  auto* sb_ptr = sb_var.get();
+  auto* wg_ptr = wg_var.get();
+  auto* priv_ptr = priv_var.get();
+
+  mod()->AddGlobalVariable(std::move(in_var));
+  mod()->AddGlobalVariable(std::move(out_var));
+  mod()->AddGlobalVariable(std::move(sb_var));
+  mod()->AddGlobalVariable(std::move(wg_var));
+  mod()->AddGlobalVariable(std::move(priv_var));
+
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("my_func", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("out_var"),
+      std::make_unique<ast::IdentifierExpression>("in_var")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("wg_var"),
+      std::make_unique<ast::IdentifierExpression>("wg_var")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("sb_var"),
+      std::make_unique<ast::IdentifierExpression>("sb_var")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("priv_var"),
+      std::make_unique<ast::IdentifierExpression>("priv_var")));
+  func->set_body(std::move(body));
+
+  mod()->AddFunction(std::move(func));
+
+  auto func2 = std::make_unique<ast::Function>("func", std::move(params), &f32);
+  auto* func2_ptr = func2.get();
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("out_var"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("my_func"),
+          ast::ExpressionList{})));
+  func2->set_body(std::move(body));
+
+  mod()->AddFunction(std::move(func2));
+
+  // Register the function
+  EXPECT_TRUE(td()->Determine());
+
+  const auto& vars = func2_ptr->referenced_module_variables();
+  ASSERT_EQ(vars.size(), 5u);
+  EXPECT_EQ(vars[0], out_ptr);
+  EXPECT_EQ(vars[1], in_ptr);
+  EXPECT_EQ(vars[2], wg_ptr);
+  EXPECT_EQ(vars[3], sb_ptr);
+  EXPECT_EQ(vars[4], priv_ptr);
+}
+
 TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) {
   ast::type::F32Type f32;