[spirv-writer] Only add used variables to entry point.
This Cl updates the entry point code to only output Input/Output
variabes which are referenced by the function instead of all
Input/Output variables.
Bug: tint:28
Change-Id: Idc429e02cac8dac7fc7b609cbd7f88039695829e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23623
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 3b1e839..f166031 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -42,6 +42,15 @@
Function::~Function() = default;
+void Function::add_referenced_module_variable(Variable* var) {
+ for (const auto* v : referenced_module_vars_) {
+ if (v->name() == var->name()) {
+ return;
+ }
+ }
+ referenced_module_vars_.push_back(var);
+}
+
bool Function::IsValid() const {
for (const auto& param : params_) {
if (param == nullptr || !param->IsValid())
diff --git a/src/ast/function.h b/src/ast/function.h
index c6ee88c..fa26849 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -68,6 +68,15 @@
/// @returns the function params
const VariableList& params() const { return params_; }
+ /// Adds the given variable to the list of referenced module variables if it
+ /// is not already included.
+ /// @param var the module variable to add
+ void add_referenced_module_variable(Variable* var);
+ /// @returns the referenced module variables
+ const std::vector<Variable*>& referenced_module_variables() const {
+ return referenced_module_vars_;
+ }
+
/// Sets the return type of the function
/// @param type the return type
void set_return_type(type::Type* type) { return_type_ = type; }
@@ -98,6 +107,7 @@
VariableList params_;
type::Type* return_type_ = nullptr;
StatementList body_;
+ std::vector<Variable*> referenced_module_vars_;
};
/// A list of unique functions
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index 300244c..5b0c394 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -57,6 +57,26 @@
EXPECT_EQ(src.column, 2u);
}
+TEST_F(FunctionTest, AddDuplicateReferencedVariables) {
+ type::VoidType void_type;
+ type::I32Type i32;
+
+ Variable v("var", StorageClass::kInput, &i32);
+ Function f("func", VariableList{}, &void_type);
+
+ f.add_referenced_module_variable(&v);
+ ASSERT_EQ(f.referenced_module_variables().size(), 1u);
+ EXPECT_EQ(f.referenced_module_variables()[0], &v);
+
+ f.add_referenced_module_variable(&v);
+ ASSERT_EQ(f.referenced_module_variables().size(), 1u);
+
+ Variable v2("var2", StorageClass::kOutput, &i32);
+ f.add_referenced_module_variable(&v2);
+ ASSERT_EQ(f.referenced_module_variables().size(), 2u);
+ EXPECT_EQ(f.referenced_module_variables()[1], &v2);
+}
+
TEST_F(FunctionTest, IsValid) {
type::VoidType void_type;
type::I32Type i32;
diff --git a/src/ast/variable.h b/src/ast/variable.h
index b84ef82..f940363 100644
--- a/src/ast/variable.h
+++ b/src/ast/variable.h
@@ -105,7 +105,7 @@
/// @param name the name to set
void set_name(const std::string& name) { name_ = name; }
/// @returns the variable name
- const std::string& name() { return name_; }
+ const std::string& name() const { return name_; }
/// Sets the value type if a const or formal parameter, or the
/// store type if a var.
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 23de278..fa6a03c 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -161,6 +161,19 @@
error_ += msg;
}
+void TypeDeterminer::set_referenced_from_function_if_needed(
+ ast::Variable* var) {
+ if (current_function_ == nullptr) {
+ return;
+ }
+ if (var->storage_class() == ast::StorageClass::kNone ||
+ var->storage_class() == ast::StorageClass::kFunction) {
+ return;
+ }
+
+ current_function_->add_referenced_module_variable(var);
+}
+
bool TypeDeterminer::Determine() {
for (const auto& var : mod_->global_variables()) {
variable_stack_.set_global(var->name(), var.get());
@@ -190,6 +203,8 @@
bool TypeDeterminer::DetermineFunction(ast::Function* func) {
name_to_function_[func->name()] = func;
+ current_function_ = func;
+
variable_stack_.push_scope();
for (const auto& param : func->params()) {
variable_stack_.set(param->name(), param.get());
@@ -200,6 +215,8 @@
}
variable_stack_.pop_scope();
+ current_function_ = nullptr;
+
return true;
}
@@ -567,6 +584,8 @@
ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
var->type(), var->storage_class())));
}
+
+ set_referenced_from_function_if_needed(var);
return true;
}
diff --git a/src/type_determiner.h b/src/type_determiner.h
index 7f62a57..80f2397 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -104,6 +104,7 @@
private:
void set_error(const Source& src, const std::string& msg);
+ void set_referenced_from_function_if_needed(ast::Variable* var);
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
bool DetermineAs(ast::AsExpression* expr);
@@ -121,6 +122,7 @@
std::string error_;
ScopeStack<ast::Variable*> variable_stack_;
std::unordered_map<std::string, ast::Function*> name_to_function_;
+ ast::Function* current_function_ = nullptr;
};
} // namespace tint
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 15fd077..319bfdf 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -743,6 +743,93 @@
EXPECT_TRUE(ident.result_type()->IsF32());
}
+TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) {
+ ast::type::F32Type f32;
+
+ auto in_var = std::make_unique<ast::Variable>(
+ "in_var", ast::StorageClass::kInput, &f32);
+ auto out_var = std::make_unique<ast::Variable>(
+ "out_var", ast::StorageClass::kOutput, &f32);
+ auto sb_var = std::make_unique<ast::Variable>(
+ "sb_var", ast::StorageClass::kStorageBuffer, &f32);
+ auto wg_var = std::make_unique<ast::Variable>(
+ "wg_var", ast::StorageClass::kWorkgroup, &f32);
+ auto priv_var = std::make_unique<ast::Variable>(
+ "priv_var", ast::StorageClass::kPrivate, &f32);
+
+ auto in_ptr = in_var.get();
+ auto out_ptr = out_var.get();
+ auto sb_ptr = sb_var.get();
+ auto wg_ptr = wg_var.get();
+ auto priv_ptr = priv_var.get();
+
+ mod()->AddGlobalVariable(std::move(in_var));
+ mod()->AddGlobalVariable(std::move(out_var));
+ mod()->AddGlobalVariable(std::move(sb_var));
+ mod()->AddGlobalVariable(std::move(wg_var));
+ mod()->AddGlobalVariable(std::move(priv_var));
+
+ ast::VariableList params;
+ auto func =
+ std::make_unique<ast::Function>("my_func", std::move(params), &f32);
+ auto func_ptr = func.get();
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("out_var"),
+ std::make_unique<ast::IdentifierExpression>("in_var")));
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("wg_var"),
+ std::make_unique<ast::IdentifierExpression>("wg_var")));
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("sb_var"),
+ std::make_unique<ast::IdentifierExpression>("sb_var")));
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("priv_var"),
+ std::make_unique<ast::IdentifierExpression>("priv_var")));
+ func->set_body(std::move(body));
+
+ mod()->AddFunction(std::move(func));
+
+ // Register the function
+ EXPECT_TRUE(td()->Determine());
+
+ const auto& vars = func_ptr->referenced_module_variables();
+ ASSERT_EQ(vars.size(), 5);
+ EXPECT_EQ(vars[0], out_ptr);
+ EXPECT_EQ(vars[1], in_ptr);
+ EXPECT_EQ(vars[2], wg_ptr);
+ EXPECT_EQ(vars[3], sb_ptr);
+ EXPECT_EQ(vars[4], priv_ptr);
+}
+
+TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) {
+ ast::type::F32Type f32;
+
+ auto var = std::make_unique<ast::Variable>(
+ "in_var", ast::StorageClass::kFunction, &f32);
+
+ ast::VariableList params;
+ auto func =
+ std::make_unique<ast::Function>("my_func", std::move(params), &f32);
+ auto func_ptr = func.get();
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("var"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f))));
+ func->set_body(std::move(body));
+
+ mod()->AddFunction(std::move(func));
+
+ // Register the function
+ EXPECT_TRUE(td()->Determine());
+
+ EXPECT_EQ(func_ptr->referenced_module_variables().size(), 0);
+}
+
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
ast::type::I32Type i32;
ast::type::F32Type f32;
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 0a81e8e..93f3133 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -164,6 +164,8 @@
}
}
+ // Note, the entry points must be generated after the functions as they need
+ // to be able to lookup the function information based on the name.
for (const auto& ep : mod_->entry_points()) {
if (!GenerateEntryPoint(ep.get())) {
return false;
@@ -296,10 +298,16 @@
OperandList operands = {Operand::Int(stage), Operand::Int(id),
Operand::String(name)};
- // TODO(dsinclair): This could be made smarter by only listing the
- // input/output variables which are used by the entry point instead of just
- // listing all module scoped variables of type input/output.
- for (const auto& var : mod_->global_variables()) {
+
+ auto* func = func_name_to_func_[ep->function_name()];
+ if (func == nullptr) {
+ error_ = "processing an entry point when the function has not been seen.";
+ return false;
+ }
+
+ for (const auto* var : func->referenced_module_variables()) {
+ // For SPIR-V 1.3 we only output Input/output variables. If we update to
+ // SPIR-V 1.4 or later this should be all variables.
if (var->storage_class() != ast::StorageClass::kInput &&
var->storage_class() != ast::StorageClass::kOutput) {
continue;
@@ -425,6 +433,7 @@
scope_stack_.pop_scope();
func_name_to_id_[func->name()] = func_id;
+ func_name_to_func_[func->name()] = func;
return true;
}
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 4213462..91505e8 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -84,36 +84,6 @@
return id;
}
- /// Sets the id for a given function name
- /// @param name the name to set
- /// @param id the id to set
- void set_func_name_to_id(const std::string& name, uint32_t id) {
- func_name_to_id_[name] = id;
- }
-
- /// Retrives the id for the given function name
- /// @param name the function name to search for
- /// @returns the id for the given name or 0 on failure
- uint32_t id_for_func_name(const std::string& name) {
- if (func_name_to_id_.count(name) == 0) {
- return 0;
- }
- return func_name_to_id_[name];
- }
-
- /// Retrieves the id for an entry point function, or 0 if not found.
- /// Emits an error if not found.
- /// @param ep the entry point
- /// @returns 0 on error
- uint32_t id_for_entry_point(ast::EntryPoint* ep) {
- auto id = id_for_func_name(ep->function_name());
- if (id == 0) {
- error_ = "unable to find ID for function: " + ep->function_name();
- return 0;
- }
- return id;
- }
-
/// Iterates over all the instructions in the correct order and calls the
/// given callback
/// @param cb the callback to execute
@@ -402,6 +372,29 @@
/// automatically.
Operand result_op();
+ /// Retrives the id for the given function name
+ /// @param name the function name to search for
+ /// @returns the id for the given name or 0 on failure
+ uint32_t id_for_func_name(const std::string& name) {
+ if (func_name_to_id_.count(name) == 0) {
+ return 0;
+ }
+ return func_name_to_id_[name];
+ }
+
+ /// Retrieves the id for an entry point function, or 0 if not found.
+ /// Emits an error if not found.
+ /// @param ep the entry point
+ /// @returns 0 on error
+ uint32_t id_for_entry_point(ast::EntryPoint* ep) {
+ auto id = id_for_func_name(ep->function_name());
+ if (id == 0) {
+ error_ = "unable to find ID for function: " + ep->function_name();
+ return 0;
+ }
+ return id;
+ }
+
ast::Module* mod_;
std::string error_;
uint32_t next_id_ = 1;
@@ -415,6 +408,7 @@
std::unordered_map<std::string, uint32_t> import_name_to_id_;
std::unordered_map<std::string, uint32_t> func_name_to_id_;
+ std::unordered_map<std::string, ast::Function*> func_name_to_func_;
std::unordered_map<std::string, uint32_t> type_name_to_id_;
std::unordered_map<std::string, uint32_t> const_to_id_;
ScopeStack<uint32_t> scope_stack_;
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index 2f1d8e0..34dc0f8 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -17,10 +17,16 @@
#include "gtest/gtest.h"
#include "spirv/unified1/spirv.h"
#include "spirv/unified1/spirv.hpp11"
+#include "src/ast/assignment_statement.h"
#include "src/ast/entry_point.h"
+#include "src/ast/function.h"
+#include "src/ast/identifier_expression.h"
#include "src/ast/pipeline_stage.h"
#include "src/ast/type/f32_type.h"
+#include "src/ast/type/void_type.h"
#include "src/ast/variable.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h"
#include "src/writer/spirv/spv_dump.h"
@@ -32,24 +38,30 @@
using BuilderTest = testing::Test;
TEST_F(BuilderTest, EntryPoint) {
+ ast::type::VoidType void_type;
+
+ ast::Function func("frag_main", {}, &void_type);
ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main");
ast::Module mod;
Builder b(&mod);
- b.set_func_name_to_id("frag_main", 2);
- ASSERT_TRUE(b.GenerateEntryPoint(&ep));
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+ ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
- EXPECT_EQ(DumpInstructions(b.preamble()), R"(OpEntryPoint Fragment %2 "main"
+ EXPECT_EQ(DumpInstructions(b.preamble()), R"(OpEntryPoint Fragment %3 "main"
)");
}
TEST_F(BuilderTest, EntryPoint_WithoutName) {
+ ast::type::VoidType void_type;
+
+ ast::Function func("compute_main", {}, &void_type);
ast::EntryPoint ep(ast::PipelineStage::kCompute, "", "compute_main");
ast::Module mod;
Builder b(&mod);
- b.set_func_name_to_id("compute_main", 3);
- ASSERT_TRUE(b.GenerateEntryPoint(&ep));
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+ ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
EXPECT_EQ(DumpInstructions(b.preamble()),
R"(OpEntryPoint GLCompute %3 "compute_main"
@@ -77,12 +89,15 @@
TEST_P(EntryPointStageTest, Emit) {
auto params = GetParam();
+ ast::type::VoidType void_type;
+
+ ast::Function func("main", {}, &void_type);
ast::EntryPoint ep(params.stage, "", "main");
ast::Module mod;
Builder b(&mod);
- b.set_func_name_to_id("main", 3);
- ASSERT_TRUE(b.GenerateEntryPoint(&ep));
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+ ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
auto preamble = b.preamble();
ASSERT_EQ(preamble.size(), 1u);
@@ -101,8 +116,12 @@
EntryPointStageData{ast::PipelineStage::kCompute,
SpvExecutionModelGLCompute}));
-TEST_F(BuilderTest, EntryPoint_WithInterfaceIds) {
+TEST_F(BuilderTest, EntryPoint_WithUnusedInterfaceIds) {
ast::type::F32Type f32;
+ ast::type::VoidType void_type;
+
+ ast::Function func("main", {}, &void_type);
+
auto v_in =
std::make_unique<ast::Variable>("my_in", ast::StorageClass::kInput, &f32);
auto v_out = std::make_unique<ast::Variable>(
@@ -121,11 +140,12 @@
mod.AddGlobalVariable(std::move(v_out));
mod.AddGlobalVariable(std::move(v_wg));
- b.set_func_name_to_id("main", 3);
- ASSERT_TRUE(b.GenerateEntryPoint(&ep));
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+ ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "my_in"
OpName %4 "my_out"
OpName %7 "my_wg"
+OpName %11 "main"
)");
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
%2 = OpTypePointer Input %3
@@ -135,35 +155,111 @@
%4 = OpVariable %5 Output %6
%8 = OpTypePointer Workgroup %3
%7 = OpVariable %8 Workgroup
+%10 = OpTypeVoid
+%9 = OpTypeFunction %10
)");
EXPECT_EQ(DumpInstructions(b.preamble()),
- R"(OpEntryPoint Vertex %3 "main" %1 %4
+ R"(OpEntryPoint Vertex %11 "main"
+)");
+}
+
+TEST_F(BuilderTest, EntryPoint_WithUsedInterfaceIds) {
+ ast::type::F32Type f32;
+ ast::type::VoidType void_type;
+
+ ast::Function func("main", {}, &void_type);
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("my_out"),
+ std::make_unique<ast::IdentifierExpression>("my_in")));
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("my_wg"),
+ std::make_unique<ast::IdentifierExpression>("my_wg")));
+ // Add duplicate usages so we show they don't get output multiple times.
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("my_out"),
+ std::make_unique<ast::IdentifierExpression>("my_in")));
+ func.set_body(std::move(body));
+
+ auto v_in =
+ std::make_unique<ast::Variable>("my_in", ast::StorageClass::kInput, &f32);
+ auto v_out = std::make_unique<ast::Variable>(
+ "my_out", ast::StorageClass::kOutput, &f32);
+ auto v_wg = std::make_unique<ast::Variable>(
+ "my_wg", ast::StorageClass::kWorkgroup, &f32);
+ ast::EntryPoint ep(ast::PipelineStage::kVertex, "", "main");
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(v_in.get());
+ td.RegisterVariableForTesting(v_out.get());
+ td.RegisterVariableForTesting(v_wg.get());
+
+ ASSERT_TRUE(td.DetermineFunction(&func)) << td.error();
+
+ Builder b(&mod);
+
+ EXPECT_TRUE(b.GenerateGlobalVariable(v_in.get())) << b.error();
+ EXPECT_TRUE(b.GenerateGlobalVariable(v_out.get())) << b.error();
+ EXPECT_TRUE(b.GenerateGlobalVariable(v_wg.get())) << b.error();
+
+ mod.AddGlobalVariable(std::move(v_in));
+ mod.AddGlobalVariable(std::move(v_out));
+ mod.AddGlobalVariable(std::move(v_wg));
+
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+ ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
+ EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "my_in"
+OpName %4 "my_out"
+OpName %7 "my_wg"
+OpName %11 "main"
+)");
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
+%2 = OpTypePointer Input %3
+%1 = OpVariable %2 Input
+%5 = OpTypePointer Output %3
+%6 = OpConstantNull %3
+%4 = OpVariable %5 Output %6
+%8 = OpTypePointer Workgroup %3
+%7 = OpVariable %8 Workgroup
+%10 = OpTypeVoid
+%9 = OpTypeFunction %10
+)");
+ EXPECT_EQ(DumpInstructions(b.preamble()),
+ R"(OpEntryPoint Vertex %11 "main" %4 %1
)");
}
TEST_F(BuilderTest, ExecutionModel_Fragment_OriginUpperLeft) {
+ ast::type::VoidType void_type;
+
+ ast::Function func("frag_main", {}, &void_type);
ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main");
ast::Module mod;
Builder b(&mod);
- b.set_func_name_to_id("frag_main", 2);
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
ASSERT_TRUE(b.GenerateExecutionModes(&ep));
EXPECT_EQ(DumpInstructions(b.preamble()),
- R"(OpExecutionMode %2 OriginUpperLeft
+ R"(OpExecutionMode %3 OriginUpperLeft
)");
}
TEST_F(BuilderTest, ExecutionModel_Compute_LocalSize) {
+ ast::type::VoidType void_type;
+
+ ast::Function func("main", {}, &void_type);
ast::EntryPoint ep(ast::PipelineStage::kCompute, "main", "main");
ast::Module mod;
Builder b(&mod);
- b.set_func_name_to_id("main", 2);
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
ASSERT_TRUE(b.GenerateExecutionModes(&ep));
EXPECT_EQ(DumpInstructions(b.preamble()),
- R"(OpExecutionMode %2 LocalSize 1 1 1
+ R"(OpExecutionMode %3 LocalSize 1 1 1
)");
}