[spirv-writer] Only add used variables to entry point.

This Cl updates the entry point code to only output Input/Output
variabes which are referenced by the function instead of all
Input/Output variables.

Bug: tint:28
Change-Id: Idc429e02cac8dac7fc7b609cbd7f88039695829e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23623
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 3b1e839..f166031 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -42,6 +42,15 @@
 
 Function::~Function() = default;
 
+void Function::add_referenced_module_variable(Variable* var) {
+  for (const auto* v : referenced_module_vars_) {
+    if (v->name() == var->name()) {
+      return;
+    }
+  }
+  referenced_module_vars_.push_back(var);
+}
+
 bool Function::IsValid() const {
   for (const auto& param : params_) {
     if (param == nullptr || !param->IsValid())
diff --git a/src/ast/function.h b/src/ast/function.h
index c6ee88c..fa26849 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -68,6 +68,15 @@
   /// @returns the function params
   const VariableList& params() const { return params_; }
 
+  /// Adds the given variable to the list of referenced module variables if it
+  /// is not already included.
+  /// @param var the module variable to add
+  void add_referenced_module_variable(Variable* var);
+  /// @returns the referenced module variables
+  const std::vector<Variable*>& referenced_module_variables() const {
+    return referenced_module_vars_;
+  }
+
   /// Sets the return type of the function
   /// @param type the return type
   void set_return_type(type::Type* type) { return_type_ = type; }
@@ -98,6 +107,7 @@
   VariableList params_;
   type::Type* return_type_ = nullptr;
   StatementList body_;
+  std::vector<Variable*> referenced_module_vars_;
 };
 
 /// A list of unique functions
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index 300244c..5b0c394 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -57,6 +57,26 @@
   EXPECT_EQ(src.column, 2u);
 }
 
+TEST_F(FunctionTest, AddDuplicateReferencedVariables) {
+  type::VoidType void_type;
+  type::I32Type i32;
+
+  Variable v("var", StorageClass::kInput, &i32);
+  Function f("func", VariableList{}, &void_type);
+
+  f.add_referenced_module_variable(&v);
+  ASSERT_EQ(f.referenced_module_variables().size(), 1u);
+  EXPECT_EQ(f.referenced_module_variables()[0], &v);
+
+  f.add_referenced_module_variable(&v);
+  ASSERT_EQ(f.referenced_module_variables().size(), 1u);
+
+  Variable v2("var2", StorageClass::kOutput, &i32);
+  f.add_referenced_module_variable(&v2);
+  ASSERT_EQ(f.referenced_module_variables().size(), 2u);
+  EXPECT_EQ(f.referenced_module_variables()[1], &v2);
+}
+
 TEST_F(FunctionTest, IsValid) {
   type::VoidType void_type;
   type::I32Type i32;
diff --git a/src/ast/variable.h b/src/ast/variable.h
index b84ef82..f940363 100644
--- a/src/ast/variable.h
+++ b/src/ast/variable.h
@@ -105,7 +105,7 @@
   /// @param name the name to set
   void set_name(const std::string& name) { name_ = name; }
   /// @returns the variable name
-  const std::string& name() { return name_; }
+  const std::string& name() const { return name_; }
 
   /// Sets the value type if a const or formal parameter, or the
   /// store type if a var.
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 23de278..fa6a03c 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -161,6 +161,19 @@
   error_ += msg;
 }
 
+void TypeDeterminer::set_referenced_from_function_if_needed(
+    ast::Variable* var) {
+  if (current_function_ == nullptr) {
+    return;
+  }
+  if (var->storage_class() == ast::StorageClass::kNone ||
+      var->storage_class() == ast::StorageClass::kFunction) {
+    return;
+  }
+
+  current_function_->add_referenced_module_variable(var);
+}
+
 bool TypeDeterminer::Determine() {
   for (const auto& var : mod_->global_variables()) {
     variable_stack_.set_global(var->name(), var.get());
@@ -190,6 +203,8 @@
 bool TypeDeterminer::DetermineFunction(ast::Function* func) {
   name_to_function_[func->name()] = func;
 
+  current_function_ = func;
+
   variable_stack_.push_scope();
   for (const auto& param : func->params()) {
     variable_stack_.set(param->name(), param.get());
@@ -200,6 +215,8 @@
   }
   variable_stack_.pop_scope();
 
+  current_function_ = nullptr;
+
   return true;
 }
 
@@ -567,6 +584,8 @@
           ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
               var->type(), var->storage_class())));
     }
+
+    set_referenced_from_function_if_needed(var);
     return true;
   }
 
diff --git a/src/type_determiner.h b/src/type_determiner.h
index 7f62a57..80f2397 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -104,6 +104,7 @@
 
  private:
   void set_error(const Source& src, const std::string& msg);
+  void set_referenced_from_function_if_needed(ast::Variable* var);
 
   bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
   bool DetermineAs(ast::AsExpression* expr);
@@ -121,6 +122,7 @@
   std::string error_;
   ScopeStack<ast::Variable*> variable_stack_;
   std::unordered_map<std::string, ast::Function*> name_to_function_;
+  ast::Function* current_function_ = nullptr;
 };
 
 }  // namespace tint
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 15fd077..319bfdf 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -743,6 +743,93 @@
   EXPECT_TRUE(ident.result_type()->IsF32());
 }
 
+TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) {
+  ast::type::F32Type f32;
+
+  auto in_var = std::make_unique<ast::Variable>(
+      "in_var", ast::StorageClass::kInput, &f32);
+  auto out_var = std::make_unique<ast::Variable>(
+      "out_var", ast::StorageClass::kOutput, &f32);
+  auto sb_var = std::make_unique<ast::Variable>(
+      "sb_var", ast::StorageClass::kStorageBuffer, &f32);
+  auto wg_var = std::make_unique<ast::Variable>(
+      "wg_var", ast::StorageClass::kWorkgroup, &f32);
+  auto priv_var = std::make_unique<ast::Variable>(
+      "priv_var", ast::StorageClass::kPrivate, &f32);
+
+  auto in_ptr = in_var.get();
+  auto out_ptr = out_var.get();
+  auto sb_ptr = sb_var.get();
+  auto wg_ptr = wg_var.get();
+  auto priv_ptr = priv_var.get();
+
+  mod()->AddGlobalVariable(std::move(in_var));
+  mod()->AddGlobalVariable(std::move(out_var));
+  mod()->AddGlobalVariable(std::move(sb_var));
+  mod()->AddGlobalVariable(std::move(wg_var));
+  mod()->AddGlobalVariable(std::move(priv_var));
+
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("my_func", std::move(params), &f32);
+  auto func_ptr = func.get();
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("out_var"),
+      std::make_unique<ast::IdentifierExpression>("in_var")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("wg_var"),
+      std::make_unique<ast::IdentifierExpression>("wg_var")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("sb_var"),
+      std::make_unique<ast::IdentifierExpression>("sb_var")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("priv_var"),
+      std::make_unique<ast::IdentifierExpression>("priv_var")));
+  func->set_body(std::move(body));
+
+  mod()->AddFunction(std::move(func));
+
+  // Register the function
+  EXPECT_TRUE(td()->Determine());
+
+  const auto& vars = func_ptr->referenced_module_variables();
+  ASSERT_EQ(vars.size(), 5);
+  EXPECT_EQ(vars[0], out_ptr);
+  EXPECT_EQ(vars[1], in_ptr);
+  EXPECT_EQ(vars[2], wg_ptr);
+  EXPECT_EQ(vars[3], sb_ptr);
+  EXPECT_EQ(vars[4], priv_ptr);
+}
+
+TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) {
+  ast::type::F32Type f32;
+
+  auto var = std::make_unique<ast::Variable>(
+      "in_var", ast::StorageClass::kFunction, &f32);
+
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("my_func", std::move(params), &f32);
+  auto func_ptr = func.get();
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("var"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::FloatLiteral>(&f32, 1.f))));
+  func->set_body(std::move(body));
+
+  mod()->AddFunction(std::move(func));
+
+  // Register the function
+  EXPECT_TRUE(td()->Determine());
+
+  EXPECT_EQ(func_ptr->referenced_module_variables().size(), 0);
+}
+
 TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
   ast::type::I32Type i32;
   ast::type::F32Type f32;
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 0a81e8e..93f3133 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -164,6 +164,8 @@
     }
   }
 
+  // Note, the entry points must be generated after the functions as they need
+  // to be able to lookup the function information based on the name.
   for (const auto& ep : mod_->entry_points()) {
     if (!GenerateEntryPoint(ep.get())) {
       return false;
@@ -296,10 +298,16 @@
 
   OperandList operands = {Operand::Int(stage), Operand::Int(id),
                           Operand::String(name)};
-  // TODO(dsinclair): This could be made smarter by only listing the
-  // input/output variables which are used by the entry point instead of just
-  // listing all module scoped variables of type input/output.
-  for (const auto& var : mod_->global_variables()) {
+
+  auto* func = func_name_to_func_[ep->function_name()];
+  if (func == nullptr) {
+    error_ = "processing an entry point when the function has not been seen.";
+    return false;
+  }
+
+  for (const auto* var : func->referenced_module_variables()) {
+    // For SPIR-V 1.3 we only output Input/output variables. If we update to
+    // SPIR-V 1.4 or later this should be all variables.
     if (var->storage_class() != ast::StorageClass::kInput &&
         var->storage_class() != ast::StorageClass::kOutput) {
       continue;
@@ -425,6 +433,7 @@
   scope_stack_.pop_scope();
 
   func_name_to_id_[func->name()] = func_id;
+  func_name_to_func_[func->name()] = func;
   return true;
 }
 
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 4213462..91505e8 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -84,36 +84,6 @@
     return id;
   }
 
-  /// Sets the id for a given function name
-  /// @param name the name to set
-  /// @param id the id to set
-  void set_func_name_to_id(const std::string& name, uint32_t id) {
-    func_name_to_id_[name] = id;
-  }
-
-  /// Retrives the id for the given function name
-  /// @param name the function name to search for
-  /// @returns the id for the given name or 0 on failure
-  uint32_t id_for_func_name(const std::string& name) {
-    if (func_name_to_id_.count(name) == 0) {
-      return 0;
-    }
-    return func_name_to_id_[name];
-  }
-
-  /// Retrieves the id for an entry point function, or 0 if not found.
-  /// Emits an error if not found.
-  /// @param ep the entry point
-  /// @returns 0 on error
-  uint32_t id_for_entry_point(ast::EntryPoint* ep) {
-    auto id = id_for_func_name(ep->function_name());
-    if (id == 0) {
-      error_ = "unable to find ID for function: " + ep->function_name();
-      return 0;
-    }
-    return id;
-  }
-
   /// Iterates over all the instructions in the correct order and calls the
   /// given callback
   /// @param cb the callback to execute
@@ -402,6 +372,29 @@
   /// automatically.
   Operand result_op();
 
+  /// Retrives the id for the given function name
+  /// @param name the function name to search for
+  /// @returns the id for the given name or 0 on failure
+  uint32_t id_for_func_name(const std::string& name) {
+    if (func_name_to_id_.count(name) == 0) {
+      return 0;
+    }
+    return func_name_to_id_[name];
+  }
+
+  /// Retrieves the id for an entry point function, or 0 if not found.
+  /// Emits an error if not found.
+  /// @param ep the entry point
+  /// @returns 0 on error
+  uint32_t id_for_entry_point(ast::EntryPoint* ep) {
+    auto id = id_for_func_name(ep->function_name());
+    if (id == 0) {
+      error_ = "unable to find ID for function: " + ep->function_name();
+      return 0;
+    }
+    return id;
+  }
+
   ast::Module* mod_;
   std::string error_;
   uint32_t next_id_ = 1;
@@ -415,6 +408,7 @@
 
   std::unordered_map<std::string, uint32_t> import_name_to_id_;
   std::unordered_map<std::string, uint32_t> func_name_to_id_;
+  std::unordered_map<std::string, ast::Function*> func_name_to_func_;
   std::unordered_map<std::string, uint32_t> type_name_to_id_;
   std::unordered_map<std::string, uint32_t> const_to_id_;
   ScopeStack<uint32_t> scope_stack_;
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index 2f1d8e0..34dc0f8 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -17,10 +17,16 @@
 #include "gtest/gtest.h"
 #include "spirv/unified1/spirv.h"
 #include "spirv/unified1/spirv.hpp11"
+#include "src/ast/assignment_statement.h"
 #include "src/ast/entry_point.h"
+#include "src/ast/function.h"
+#include "src/ast/identifier_expression.h"
 #include "src/ast/pipeline_stage.h"
 #include "src/ast/type/f32_type.h"
+#include "src/ast/type/void_type.h"
 #include "src/ast/variable.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
 #include "src/writer/spirv/builder.h"
 #include "src/writer/spirv/spv_dump.h"
 
@@ -32,24 +38,30 @@
 using BuilderTest = testing::Test;
 
 TEST_F(BuilderTest, EntryPoint) {
+  ast::type::VoidType void_type;
+
+  ast::Function func("frag_main", {}, &void_type);
   ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main");
 
   ast::Module mod;
   Builder b(&mod);
-  b.set_func_name_to_id("frag_main", 2);
-  ASSERT_TRUE(b.GenerateEntryPoint(&ep));
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+  ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
 
-  EXPECT_EQ(DumpInstructions(b.preamble()), R"(OpEntryPoint Fragment %2 "main"
+  EXPECT_EQ(DumpInstructions(b.preamble()), R"(OpEntryPoint Fragment %3 "main"
 )");
 }
 
 TEST_F(BuilderTest, EntryPoint_WithoutName) {
+  ast::type::VoidType void_type;
+
+  ast::Function func("compute_main", {}, &void_type);
   ast::EntryPoint ep(ast::PipelineStage::kCompute, "", "compute_main");
 
   ast::Module mod;
   Builder b(&mod);
-  b.set_func_name_to_id("compute_main", 3);
-  ASSERT_TRUE(b.GenerateEntryPoint(&ep));
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+  ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
 
   EXPECT_EQ(DumpInstructions(b.preamble()),
             R"(OpEntryPoint GLCompute %3 "compute_main"
@@ -77,12 +89,15 @@
 TEST_P(EntryPointStageTest, Emit) {
   auto params = GetParam();
 
+  ast::type::VoidType void_type;
+
+  ast::Function func("main", {}, &void_type);
   ast::EntryPoint ep(params.stage, "", "main");
 
   ast::Module mod;
   Builder b(&mod);
-  b.set_func_name_to_id("main", 3);
-  ASSERT_TRUE(b.GenerateEntryPoint(&ep));
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+  ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
 
   auto preamble = b.preamble();
   ASSERT_EQ(preamble.size(), 1u);
@@ -101,8 +116,12 @@
                     EntryPointStageData{ast::PipelineStage::kCompute,
                                         SpvExecutionModelGLCompute}));
 
-TEST_F(BuilderTest, EntryPoint_WithInterfaceIds) {
+TEST_F(BuilderTest, EntryPoint_WithUnusedInterfaceIds) {
   ast::type::F32Type f32;
+  ast::type::VoidType void_type;
+
+  ast::Function func("main", {}, &void_type);
+
   auto v_in =
       std::make_unique<ast::Variable>("my_in", ast::StorageClass::kInput, &f32);
   auto v_out = std::make_unique<ast::Variable>(
@@ -121,11 +140,12 @@
   mod.AddGlobalVariable(std::move(v_out));
   mod.AddGlobalVariable(std::move(v_wg));
 
-  b.set_func_name_to_id("main", 3);
-  ASSERT_TRUE(b.GenerateEntryPoint(&ep));
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+  ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
   EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "my_in"
 OpName %4 "my_out"
 OpName %7 "my_wg"
+OpName %11 "main"
 )");
   EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
 %2 = OpTypePointer Input %3
@@ -135,35 +155,111 @@
 %4 = OpVariable %5 Output %6
 %8 = OpTypePointer Workgroup %3
 %7 = OpVariable %8 Workgroup
+%10 = OpTypeVoid
+%9 = OpTypeFunction %10
 )");
   EXPECT_EQ(DumpInstructions(b.preamble()),
-            R"(OpEntryPoint Vertex %3 "main" %1 %4
+            R"(OpEntryPoint Vertex %11 "main"
+)");
+}
+
+TEST_F(BuilderTest, EntryPoint_WithUsedInterfaceIds) {
+  ast::type::F32Type f32;
+  ast::type::VoidType void_type;
+
+  ast::Function func("main", {}, &void_type);
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("my_out"),
+      std::make_unique<ast::IdentifierExpression>("my_in")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("my_wg"),
+      std::make_unique<ast::IdentifierExpression>("my_wg")));
+  // Add duplicate usages so we show they don't get output multiple times.
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("my_out"),
+      std::make_unique<ast::IdentifierExpression>("my_in")));
+  func.set_body(std::move(body));
+
+  auto v_in =
+      std::make_unique<ast::Variable>("my_in", ast::StorageClass::kInput, &f32);
+  auto v_out = std::make_unique<ast::Variable>(
+      "my_out", ast::StorageClass::kOutput, &f32);
+  auto v_wg = std::make_unique<ast::Variable>(
+      "my_wg", ast::StorageClass::kWorkgroup, &f32);
+  ast::EntryPoint ep(ast::PipelineStage::kVertex, "", "main");
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(v_in.get());
+  td.RegisterVariableForTesting(v_out.get());
+  td.RegisterVariableForTesting(v_wg.get());
+
+  ASSERT_TRUE(td.DetermineFunction(&func)) << td.error();
+
+  Builder b(&mod);
+
+  EXPECT_TRUE(b.GenerateGlobalVariable(v_in.get())) << b.error();
+  EXPECT_TRUE(b.GenerateGlobalVariable(v_out.get())) << b.error();
+  EXPECT_TRUE(b.GenerateGlobalVariable(v_wg.get())) << b.error();
+
+  mod.AddGlobalVariable(std::move(v_in));
+  mod.AddGlobalVariable(std::move(v_out));
+  mod.AddGlobalVariable(std::move(v_wg));
+
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+  ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "my_in"
+OpName %4 "my_out"
+OpName %7 "my_wg"
+OpName %11 "main"
+)");
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
+%2 = OpTypePointer Input %3
+%1 = OpVariable %2 Input
+%5 = OpTypePointer Output %3
+%6 = OpConstantNull %3
+%4 = OpVariable %5 Output %6
+%8 = OpTypePointer Workgroup %3
+%7 = OpVariable %8 Workgroup
+%10 = OpTypeVoid
+%9 = OpTypeFunction %10
+)");
+  EXPECT_EQ(DumpInstructions(b.preamble()),
+            R"(OpEntryPoint Vertex %11 "main" %4 %1
 )");
 }
 
 TEST_F(BuilderTest, ExecutionModel_Fragment_OriginUpperLeft) {
+  ast::type::VoidType void_type;
+
+  ast::Function func("frag_main", {}, &void_type);
   ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main");
 
   ast::Module mod;
   Builder b(&mod);
-  b.set_func_name_to_id("frag_main", 2);
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
   ASSERT_TRUE(b.GenerateExecutionModes(&ep));
 
   EXPECT_EQ(DumpInstructions(b.preamble()),
-            R"(OpExecutionMode %2 OriginUpperLeft
+            R"(OpExecutionMode %3 OriginUpperLeft
 )");
 }
 
 TEST_F(BuilderTest, ExecutionModel_Compute_LocalSize) {
+  ast::type::VoidType void_type;
+
+  ast::Function func("main", {}, &void_type);
   ast::EntryPoint ep(ast::PipelineStage::kCompute, "main", "main");
 
   ast::Module mod;
   Builder b(&mod);
-  b.set_func_name_to_id("main", 2);
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
   ASSERT_TRUE(b.GenerateExecutionModes(&ep));
 
   EXPECT_EQ(DumpInstructions(b.preamble()),
-            R"(OpExecutionMode %2 LocalSize 1 1 1
+            R"(OpExecutionMode %3 LocalSize 1 1 1
 )");
 }