Inherit refererenced globals up the call stack. A given caller should inherit the globals referenced from a callee. This way, a given entry point will have a list of all the variables used up the stack which it needs to reference. Change-Id: Ib6efcdd5c3347749ad2d54aecfa425bd966a62fd Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24762 Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 82d45a4..6b1d139 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc
@@ -477,6 +477,18 @@ } else { if (current_function_) { caller_to_callee_[current_function_->name()].push_back(ident->name()); + + auto* callee_func = mod_->FindFunctionByName(ident->name()); + if (callee_func == nullptr) { + set_error(expr->source(), + "unable to find called function: " + ident->name()); + return false; + } + + // We inherit any referenced variables from the callee. + for (auto* var : callee_func->referenced_module_variables()) { + set_referenced_from_function_if_needed(var); + } } // An identifier with a single name is a function call, not an import
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 183bbc2..983ac36 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc
@@ -803,6 +803,76 @@ EXPECT_EQ(vars[4], priv_ptr); } +TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) { + ast::type::F32Type f32; + + auto in_var = std::make_unique<ast::Variable>( + "in_var", ast::StorageClass::kInput, &f32); + auto out_var = std::make_unique<ast::Variable>( + "out_var", ast::StorageClass::kOutput, &f32); + auto sb_var = std::make_unique<ast::Variable>( + "sb_var", ast::StorageClass::kStorageBuffer, &f32); + auto wg_var = std::make_unique<ast::Variable>( + "wg_var", ast::StorageClass::kWorkgroup, &f32); + auto priv_var = std::make_unique<ast::Variable>( + "priv_var", ast::StorageClass::kPrivate, &f32); + + auto* in_ptr = in_var.get(); + auto* out_ptr = out_var.get(); + auto* sb_ptr = sb_var.get(); + auto* wg_ptr = wg_var.get(); + auto* priv_ptr = priv_var.get(); + + mod()->AddGlobalVariable(std::move(in_var)); + mod()->AddGlobalVariable(std::move(out_var)); + mod()->AddGlobalVariable(std::move(sb_var)); + mod()->AddGlobalVariable(std::move(wg_var)); + mod()->AddGlobalVariable(std::move(priv_var)); + + ast::VariableList params; + auto func = + std::make_unique<ast::Function>("my_func", std::move(params), &f32); + + ast::StatementList body; + body.push_back(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("out_var"), + std::make_unique<ast::IdentifierExpression>("in_var"))); + body.push_back(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("wg_var"), + std::make_unique<ast::IdentifierExpression>("wg_var"))); + body.push_back(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("sb_var"), + std::make_unique<ast::IdentifierExpression>("sb_var"))); + body.push_back(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("priv_var"), + std::make_unique<ast::IdentifierExpression>("priv_var"))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + auto func2 = std::make_unique<ast::Function>("func", std::move(params), &f32); + auto* func2_ptr = func2.get(); + body.push_back(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("out_var"), + std::make_unique<ast::CallExpression>( + std::make_unique<ast::IdentifierExpression>("my_func"), + ast::ExpressionList{}))); + func2->set_body(std::move(body)); + + mod()->AddFunction(std::move(func2)); + + // Register the function + EXPECT_TRUE(td()->Determine()); + + const auto& vars = func2_ptr->referenced_module_variables(); + ASSERT_EQ(vars.size(), 5u); + EXPECT_EQ(vars[0], out_ptr); + EXPECT_EQ(vars[1], in_ptr); + EXPECT_EQ(vars[2], wg_ptr); + EXPECT_EQ(vars[3], sb_ptr); + EXPECT_EQ(vars[4], priv_ptr); +} + TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) { ast::type::F32Type f32;