[msl-writer] Only emit in/out structs if function call requires.
This CL updates the EmitCall method to only ouptut the input and output
struct name if there are in/out variables used in the function.
Bug: tint:107
Change-Id: Ic0c7722c8796c2f9baa3515cb46be0568f9e1ac3
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25400
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index f826213..4a4119e 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -420,23 +420,6 @@
if (it != ep_func_name_remapped_.end()) {
name = it->second;
}
- out_ << name << "(";
-
- bool first = true;
- auto var_name = current_ep_var_name(VarType::kIn);
- if (!var_name.empty()) {
- out_ << var_name;
- first = false;
- }
-
- var_name = current_ep_var_name(VarType::kOut);
- if (!var_name.empty()) {
- if (!first) {
- out_ << ", ";
- }
- first = false;
- out_ << var_name;
- }
auto* func = module_->FindFunctionByName(ident->name());
if (func == nullptr) {
@@ -444,6 +427,27 @@
return false;
}
+ out_ << name << "(";
+
+ bool first = true;
+ if (has_referenced_in_var_needing_struct(func)) {
+ auto var_name = current_ep_var_name(VarType::kIn);
+ if (!var_name.empty()) {
+ out_ << var_name;
+ first = false;
+ }
+ }
+ if (has_referenced_out_var_needing_struct(func)) {
+ auto var_name = current_ep_var_name(VarType::kOut);
+ if (!var_name.empty()) {
+ if (!first) {
+ out_ << ", ";
+ }
+ first = false;
+ out_ << var_name;
+ }
+ }
+
for (const auto& data : func->referenced_builtin_variables()) {
auto* var = data.first;
if (var->storage_class() != ast::StorageClass::kInput) {
@@ -945,11 +949,20 @@
return;
}
-bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) {
+bool GeneratorImpl::has_referenced_in_var_needing_struct(ast::Function* func) {
for (auto data : func->referenced_location_variables()) {
auto* var = data.first;
- if (var->storage_class() == ast::StorageClass::kInput ||
- var->storage_class() == ast::StorageClass::kOutput) {
+ if (var->storage_class() == ast::StorageClass::kInput) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool GeneratorImpl::has_referenced_out_var_needing_struct(ast::Function* func) {
+ for (auto data : func->referenced_location_variables()) {
+ auto* var = data.first;
+ if (var->storage_class() == ast::StorageClass::kOutput) {
return true;
}
}
@@ -960,10 +973,14 @@
return true;
}
}
-
return false;
}
+bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) {
+ return has_referenced_in_var_needing_struct(func) ||
+ has_referenced_out_var_needing_struct(func);
+}
+
bool GeneratorImpl::EmitFunction(ast::Function* func) {
make_indent();
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 3176394..e2a0e38 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -207,6 +207,14 @@
/// @returns true if the zero value was successfully emitted.
bool EmitZeroValue(ast::type::Type* type);
+ /// Determines if the function needs the input struct passed to it.
+ /// @param func the function to check
+ /// @returns true if there are input struct variables used in the function
+ bool has_referenced_in_var_needing_struct(ast::Function* func);
+ /// Determines if the function needs the output struct passed to it.
+ /// @param func the function to check
+ /// @returns true if there are output struct variables used in the function
+ bool has_referenced_out_var_needing_struct(ast::Function* func);
/// Determines if any used module variable requires an input or output struct.
/// @param func the function to check
/// @returns true if an input or output struct is required.
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc
index 297f6cd..0fb8de8 100644
--- a/src/writer/msl/generator_impl_function_test.cc
+++ b/src/writer/msl/generator_impl_function_test.cc
@@ -503,6 +503,84 @@
}
TEST_F(MslGeneratorImplTest,
+ Emit_Function_Called_By_EntryPoints_NoUsedGlobals) {
+ ast::type::VoidType void_type;
+ ast::type::F32Type f32;
+ ast::type::VectorType vec4(&f32, 4);
+
+ auto depth_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "depth", ast::StorageClass::kOutput, &f32));
+
+ ast::VariableDecorationList decos;
+ decos.push_back(
+ std::make_unique<ast::BuiltinDecoration>(ast::Builtin::kFragDepth));
+ depth_var->set_decorations(std::move(decos));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(depth_var.get());
+
+ mod.AddGlobalVariable(std::move(depth_var));
+
+ ast::VariableList params;
+ params.push_back(std::make_unique<ast::Variable>(
+ "param", ast::StorageClass::kFunction, &f32));
+ auto sub_func =
+ std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::ReturnStatement>(
+ std::make_unique<ast::IdentifierExpression>("param")));
+ sub_func->set_body(std::move(body));
+
+ mod.AddFunction(std::move(sub_func));
+
+ auto func_1 = std::make_unique<ast::Function>("frag_1_main",
+ std::move(params), &void_type);
+
+ ast::ExpressionList expr;
+ expr.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("depth"),
+ std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("sub_func"),
+ std::move(expr))));
+ body.push_back(std::make_unique<ast::ReturnStatement>());
+ func_1->set_body(std::move(body));
+
+ mod.AddFunction(std::move(func_1));
+
+ auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+ "ep_1", "frag_1_main");
+ mod.AddEntryPoint(std::move(ep1));
+
+ ASSERT_TRUE(td.Determine()) << td.error();
+
+ GeneratorImpl g(&mod);
+ ASSERT_TRUE(g.Generate()) << g.error();
+ EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+struct ep_1_out {
+ float depth [[depth(any)]];
+};
+
+float sub_func(float param) {
+ return param;
+}
+
+fragment ep_1_out ep_1() {
+ ep_1_out tint_out = {};
+ tint_out.depth = sub_func(1.00000000f);
+ return tint_out;
+}
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest,
Emit_Function_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) {
ast::type::VoidType void_type;
ast::type::F32Type f32;