[hlsl-writer] Add support for input locations and builtins.

This CL adds the beginning of support for input/output locations and
builtins in the HLSL backend.

Bug: tint:7
Change-Id: I8fb01707b50635a800b0d7317cf4a8f62f12cfca
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/26780
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index eb6edf4..6e53100 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -20,6 +20,7 @@
 #include "src/ast/binary_expression.h"
 #include "src/ast/bool_literal.h"
 #include "src/ast/case_statement.h"
+#include "src/ast/decorated_variable.h"
 #include "src/ast/else_statement.h"
 #include "src/ast/float_literal.h"
 #include "src/ast/identifier_expression.h"
@@ -46,6 +47,11 @@
 namespace hlsl {
 namespace {
 
+const char kInStructNameSuffix[] = "in";
+const char kOutStructNameSuffix[] = "out";
+const char kTintStructInVarPrefix[] = "tint_in";
+const char kTintStructOutVarPrefix[] = "tint_out";
+
 bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
   if (stmts->empty()) {
     return false;
@@ -74,6 +80,11 @@
     out_ << std::endl;
   }
 
+  for (const auto& ep : module_->entry_points()) {
+    if (!EmitEntryPointData(ep.get())) {
+      return false;
+    }
+  }
   for (const auto& func : module_->functions()) {
     if (!EmitFunction(func.get())) {
       return false;
@@ -89,6 +100,17 @@
   return true;
 }
 
+std::string GeneratorImpl::generate_name(const std::string& prefix) {
+  std::string name = prefix;
+  uint32_t i = 0;
+  while (namer_.IsMapped(name)) {
+    name = prefix + "_" + std::to_string(i);
+    ++i;
+  }
+  namer_.RegisterRemappedName(name);
+  return name;
+}
+
 std::string GeneratorImpl::current_ep_var_name(VarType type) {
   std::string name = "";
   switch (type) {
@@ -431,8 +453,12 @@
   return false;
 }
 
-bool GeneratorImpl::global_is_in_struct(ast::Variable*) const {
-  return false;
+bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const {
+  return var->IsDecorated() &&
+         (var->AsDecorated()->HasLocationDecoration() ||
+          var->AsDecorated()->HasBuiltinDecoration()) &&
+         (var->storage_class() == ast::StorageClass::kInput ||
+          var->storage_class() == ast::StorageClass::kOutput);
 }
 
 bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
@@ -499,6 +525,25 @@
   return EmitBlock(stmt->body());
 }
 
+bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) {
+  for (auto data : func->referenced_location_variables()) {
+    auto* var = data.first;
+    if (var->storage_class() == ast::StorageClass::kOutput ||
+        var->storage_class() == ast::StorageClass::kInput) {
+      return true;
+    }
+  }
+
+  for (auto data : func->referenced_builtin_variables()) {
+    auto* var = data.first;
+    if (var->storage_class() == ast::StorageClass::kOutput ||
+        var->storage_class() == ast::StorageClass::kInput) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool GeneratorImpl::EmitFunction(ast::Function* func) {
   make_indent();
 
@@ -507,6 +552,33 @@
     return true;
   }
 
+  // TODO(dsinclair): This could be smarter. If the input/outputs for multiple
+  // entry points are the same we could generate a single struct and then have
+  // this determine it's the same struct and just emit once.
+  bool emit_duplicate_functions = func->ancestor_entry_points().size() > 0 &&
+                                  has_referenced_var_needing_struct(func);
+
+  if (emit_duplicate_functions) {
+    for (const auto& ep_name : func->ancestor_entry_points()) {
+      if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) {
+        return false;
+      }
+      out_ << std::endl;
+    }
+  } else {
+    // Emit as non-duplicated
+    if (!EmitFunctionInternal(func, false, "")) {
+      return false;
+    }
+    out_ << std::endl;
+  }
+
+  return true;
+}
+
+bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
+                                         bool emit_duplicate_functions,
+                                         const std::string& ep_name) {
   auto name = func->name();
 
   if (!EmitType(func->return_type(), "")) {
@@ -516,6 +588,30 @@
   out_ << " " << namer_.NameFor(name) << "(";
 
   bool first = true;
+
+  // If we're emitting duplicate functions that means the function takes
+  // the stage_in or stage_out value from the entry point, emit them.
+  //
+  // We emit both of them if they're there regardless of if they're both used.
+  if (emit_duplicate_functions) {
+    auto in_it = ep_name_to_in_data_.find(ep_name);
+    if (in_it != ep_name_to_in_data_.end()) {
+      out_ << "in " << in_it->second.struct_name << " "
+           << in_it->second.var_name;
+      first = false;
+    }
+
+    auto out_it = ep_name_to_out_data_.find(ep_name);
+    if (out_it != ep_name_to_out_data_.end()) {
+      if (!first) {
+        out_ << ", ";
+      }
+      out_ << "out " << out_it->second.struct_name << " "
+           << out_it->second.var_name;
+      first = false;
+    }
+  }
+
   for (const auto& v : func->params()) {
     if (!first) {
       out_ << ", ";
@@ -533,19 +629,188 @@
 
   out_ << ") ";
 
+  current_ep_name_ = ep_name;
+
   if (!EmitBlockAndNewline(func->body())) {
     return false;
   }
 
+  current_ep_name_ = "";
+
+  return true;
+}
+
+bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) {
+  auto* func = module_->FindFunctionByName(ep->function_name());
+  if (func == nullptr) {
+    error_ = "Unable to find entry point function: " + ep->function_name();
+    return false;
+  }
+
+  std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> in_variables;
+  std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>>
+      out_variables;
+  for (auto data : func->referenced_location_variables()) {
+    auto* var = data.first;
+    auto* deco = data.second;
+
+    if (var->storage_class() == ast::StorageClass::kInput) {
+      in_variables.push_back({var, deco});
+    } else if (var->storage_class() == ast::StorageClass::kOutput) {
+      out_variables.push_back({var, deco});
+    }
+  }
+
+  for (auto data : func->referenced_builtin_variables()) {
+    auto* var = data.first;
+    auto* deco = data.second;
+
+    if (var->storage_class() == ast::StorageClass::kInput) {
+      in_variables.push_back({var, deco});
+    } else if (var->storage_class() == ast::StorageClass::kOutput) {
+      out_variables.push_back({var, deco});
+    }
+  }
+
+  auto ep_name = ep->name();
+  if (ep_name.empty()) {
+    ep_name = ep->function_name();
+  }
+
+  // TODO(dsinclair): There is a potential bug here. Entry points can have the
+  // same name in WGSL if they have different pipeline stages. This does not
+  // take that into account and will emit duplicate struct names. I'm ignoring
+  // this until https://github.com/gpuweb/gpuweb/issues/662 is resolved as it
+  // may remove this issue and entry point names will need to be unique.
+  if (!in_variables.empty()) {
+    auto in_struct_name = generate_name(ep_name + "_" + kInStructNameSuffix);
+    auto in_var_name = generate_name(kTintStructInVarPrefix);
+    ep_name_to_in_data_[ep_name] = {in_struct_name, in_var_name};
+
+    make_indent();
+    out_ << "struct " << in_struct_name << " {" << std::endl;
+
+    increment_indent();
+
+    for (auto& data : in_variables) {
+      auto* var = data.first;
+      auto* deco = data.second;
+
+      make_indent();
+      if (!EmitType(var->type(), var->name())) {
+        return false;
+      }
+
+      out_ << " " << var->name() << " : ";
+      if (deco->IsLocation()) {
+        out_ << "TEXCOORD" << deco->AsLocation()->value();
+      } else if (deco->IsBuiltin()) {
+        auto attr = builtin_to_attribute(deco->AsBuiltin()->value());
+        if (attr.empty()) {
+          error_ = "unsupported builtin";
+          return false;
+        }
+        out_ << attr;
+      } else {
+        error_ = "unsupported variable decoration for entry point output";
+        return false;
+      }
+      out_ << ";" << std::endl;
+    }
+    decrement_indent();
+    make_indent();
+
+    out_ << "};" << std::endl << std::endl;
+  }
+
+  if (!out_variables.empty()) {
+    auto out_struct_name = generate_name(ep_name + "_" + kOutStructNameSuffix);
+    auto out_var_name = generate_name(kTintStructOutVarPrefix);
+    ep_name_to_out_data_[ep_name] = {out_struct_name, out_var_name};
+
+    make_indent();
+    out_ << "struct " << out_struct_name << " {" << std::endl;
+
+    increment_indent();
+    for (auto& data : out_variables) {
+      auto* var = data.first;
+      auto* deco = data.second;
+
+      make_indent();
+      if (!EmitType(var->type(), var->name())) {
+        return false;
+      }
+
+      out_ << " " << var->name() << " : ";
+
+      if (deco->IsLocation()) {
+        auto loc = deco->AsLocation()->value();
+        if (ep->stage() == ast::PipelineStage::kVertex) {
+          out_ << "TEXCOORD" << loc;
+        } else if (ep->stage() == ast::PipelineStage::kFragment) {
+          out_ << "SV_Target" << loc << "";
+        } else {
+          error_ = "invalid location variable for pipeline stage";
+          return false;
+        }
+      } else if (deco->IsBuiltin()) {
+        auto attr = builtin_to_attribute(deco->AsBuiltin()->value());
+        if (attr.empty()) {
+          error_ = "unsupported builtin";
+          return false;
+        }
+        out_ << attr;
+      } else {
+        error_ = "unsupported variable decoration for entry point output";
+        return false;
+      }
+      out_ << ";" << std::endl;
+    }
+    decrement_indent();
+    make_indent();
+    out_ << "};" << std::endl << std::endl;
+  }
+
   return true;
 }
 
+std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
+  switch (builtin) {
+    case ast::Builtin::kPosition:
+      return "SV_Position";
+    case ast::Builtin::kVertexIdx:
+      return "SV_VertexID";
+    case ast::Builtin::kInstanceIdx:
+      return "SV_InstanceID";
+    case ast::Builtin::kFrontFacing:
+      return "SV_IsFrontFacing";
+    case ast::Builtin::kFragCoord:
+      return "SV_Position";
+    case ast::Builtin::kFragDepth:
+      return "SV_Depth";
+    // TODO(dsinclair): Ignore for now. This has been removed as a builtin
+    // in the spec. Need to update Tint to match.
+    // https://github.com/gpuweb/gpuweb/pull/824
+    case ast::Builtin::kWorkgroupSize:
+      return "";
+    case ast::Builtin::kLocalInvocationId:
+      return "SV_GroupThreadID";
+    case ast::Builtin::kLocalInvocationIdx:
+      return "SV_GroupIndex";
+    case ast::Builtin::kGlobalInvocationId:
+      return "SV_DispatchThreadID";
+    default:
+      break;
+  }
+  return "";
+}
+
 bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) {
   make_indent();
 
-  auto current_ep_name = ep->name();
-  if (current_ep_name.empty()) {
-    current_ep_name = ep->function_name();
+  current_ep_name_ = ep->name();
+  if (current_ep_name_.empty()) {
+    current_ep_name_ = ep->function_name();
   }
 
   auto* func = module_->FindFunctionByName(ep->function_name());
@@ -554,19 +819,43 @@
     return false;
   }
 
-  out_ << "void " << namer_.NameFor(current_ep_name) << "() {" << std::endl;
+  auto out_data = ep_name_to_out_data_.find(current_ep_name_);
+  bool has_out_data = out_data != ep_name_to_out_data_.end();
+  if (has_out_data) {
+    out_ << out_data->second.struct_name;
+  } else {
+    out_ << "void";
+  }
+  out_ << " " << namer_.NameFor(current_ep_name_) << "(";
+
+  auto in_data = ep_name_to_in_data_.find(current_ep_name_);
+  if (in_data != ep_name_to_in_data_.end()) {
+    out_ << in_data->second.struct_name << " " << in_data->second.var_name;
+  }
+  out_ << ") {" << std::endl;
+
   increment_indent();
 
+  if (has_out_data) {
+    make_indent();
+    out_ << out_data->second.struct_name << " " << out_data->second.var_name
+         << ";" << std::endl;
+  }
+
+  generating_entry_point_ = true;
   for (const auto& s : *(func->body())) {
     if (!EmitStatement(s.get())) {
       return false;
     }
   }
+  generating_entry_point_ = false;
 
   decrement_indent();
   make_indent();
   out_ << "}" << std::endl;
 
+  current_ep_name_ = "";
+
   return true;
 }
 
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index c24f977..b65b94d 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -110,6 +110,19 @@
   /// @param func the function to generate
   /// @returns true if the function was emitted
   bool EmitFunction(ast::Function* func);
+  /// Internal helper for emitting functions
+  /// @param func the function to emit
+  /// @param emit_duplicate_functions set true if we need to duplicate per entry
+  /// point
+  /// @param ep_name the current entry point or blank if none set
+  /// @returns true if the function was emitted.
+  bool EmitFunctionInternal(ast::Function* func,
+                            bool emit_duplicate_functions,
+                            const std::string& ep_name);
+  /// Handles emitting information for an entry point
+  /// @param ep the entry point
+  /// @returns true if the entry point data was emitted
+  bool EmitEntryPointData(ast::EntryPoint* ep);
   /// Handles emitting the entry point function
   /// @param ep the entry point
   /// @returns true if the entry point function was emitted
@@ -168,6 +181,21 @@
   /// @param var the variable to check
   /// @returns true if the global is in an input or output struct
   bool global_is_in_struct(ast::Variable* var) const;
+  /// Generates a name for the prefix
+  /// @param prefix the prefix of the name to generate
+  /// @returns the name
+  std::string generate_name(const std::string& prefix);
+  /// Converts a builtin to an attribute name
+  /// @param builtin the builtin to convert
+  /// @returns the string name of the builtin or blank on error
+  std::string builtin_to_attribute(ast::Builtin builtin) const;
+  /// Determines if any used module variable requires an input or output struct.
+  /// @param func the function to check
+  /// @returns true if an input or output struct is required.
+  bool has_referenced_var_needing_struct(ast::Function* func);
+
+  /// @returns the namer for testing
+  Namer* namer_for_testing() { return &namer_; }
 
  private:
   enum class VarType { kIn, kOut };
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 2216df5..0983554 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -67,6 +67,7 @@
   EXPECT_EQ(g.result(), R"(  void my_func() {
     return;
   }
+
 )");
 }
 
@@ -90,6 +91,7 @@
   EXPECT_EQ(g.result(), R"(  void GeometryShader_tint_0() {
     return;
   }
+
 )");
 }
 
@@ -121,6 +123,7 @@
   EXPECT_EQ(g.result(), R"(  void my_func(float a, int b) {
     return;
   }
+
 )");
 }
 
@@ -144,7 +147,7 @@
 )");
 }
 
-TEST_F(HlslGeneratorImplTest, DISABLED_Emit_Function_EntryPoint_WithInOutVars) {
+TEST_F(HlslGeneratorImplTest, Emit_Function_EntryPoint_WithInOutVars) {
   ast::type::VoidType void_type;
   ast::type::F32Type f32;
 
@@ -207,8 +210,7 @@
 )");
 }
 
-TEST_F(HlslGeneratorImplTest,
-       DISABLED_Emit_Function_EntryPoint_WithInOut_Builtins) {
+TEST_F(HlslGeneratorImplTest, Emit_Function_EntryPoint_WithInOut_Builtins) {
   ast::type::VoidType void_type;
   ast::type::F32Type f32;
   ast::type::VectorType vec4(&f32, 4);
@@ -262,7 +264,7 @@
   GeneratorImpl g(&mod);
   ASSERT_TRUE(g.Generate()) << g.error();
   EXPECT_EQ(g.result(), R"(struct frag_main_in {
-  float gl_FragCoord : SV_Position;
+  vector<float, 4> coord : SV_Position;
 };
 
 struct frag_main_out {
@@ -271,7 +273,7 @@
 
 frag_main_out frag_main(frag_main_in tint_in) {
   frag_main_out tint_out;
-  tint_out.depth = tint_in.gl_FragCoord.x;
+  tint_out.depth = tint_in.coord.x;
   return tint_out;
 }
 
@@ -377,6 +379,7 @@
   EXPECT_EQ(g.result(), R"( ... )");
 }
 
+// TODO(dsinclair): Requires CallExpression
 TEST_F(
     HlslGeneratorImplTest,
     DISABLED_Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) {
@@ -480,6 +483,7 @@
 )");
 }
 
+// TODO(dsinclair): Requires CallExpression
 TEST_F(HlslGeneratorImplTest,
        DISABLED_Emit_Function_Called_By_EntryPoints_NoUsedGlobals) {
   ast::type::VoidType void_type;
@@ -558,6 +562,7 @@
 )");
 }
 
+// TODO(dsinclair): Requires CallExpression
 TEST_F(
     HlslGeneratorImplTest,
     DISABLED_Emit_Function_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) {
@@ -1098,6 +1103,7 @@
   EXPECT_EQ(g.result(), R"(  void my_func(float a[5]) {
     return;
   }
+
 )");
 }
 
diff --git a/src/writer/hlsl/generator_impl_test.cc b/src/writer/hlsl/generator_impl_test.cc
index df128ad..b1b16db 100644
--- a/src/writer/hlsl/generator_impl_test.cc
+++ b/src/writer/hlsl/generator_impl_test.cc
@@ -19,6 +19,7 @@
 #include "gtest/gtest.h"
 #include "src/ast/entry_point.h"
 #include "src/ast/function.h"
+#include "src/ast/identifier_expression.h"
 #include "src/ast/module.h"
 #include "src/ast/type/void_type.h"
 
@@ -46,6 +47,66 @@
 )");
 }
 
+TEST_F(HlslGeneratorImplTest, InputStructName) {
+  ast::Module m;
+  GeneratorImpl g(&m);
+  ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
+}
+
+TEST_F(HlslGeneratorImplTest, InputStructName_ConflictWithExisting) {
+  ast::Module m;
+  GeneratorImpl g(&m);
+
+  // Register the struct name as existing.
+  auto* namer = g.namer_for_testing();
+  namer->NameFor("func_main_out");
+
+  ASSERT_EQ(g.generate_name("func_main_out"), "func_main_out_0");
+}
+
+TEST_F(HlslGeneratorImplTest, NameConflictWith_InputStructName) {
+  ast::Module m;
+  GeneratorImpl g(&m);
+  ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
+
+  ast::IdentifierExpression ident("func_main_in");
+  ASSERT_TRUE(g.EmitIdentifier(&ident));
+  EXPECT_EQ(g.result(), "func_main_in_0");
+}
+
+struct HlslBuiltinData {
+  ast::Builtin builtin;
+  const char* attribute_name;
+};
+inline std::ostream& operator<<(std::ostream& out, HlslBuiltinData data) {
+  out << data.builtin;
+  return out;
+}
+using HlslBuiltinConversionTest = testing::TestWithParam<HlslBuiltinData>;
+TEST_P(HlslBuiltinConversionTest, Emit) {
+  auto params = GetParam();
+
+  ast::Module m;
+  GeneratorImpl g(&m);
+  EXPECT_EQ(g.builtin_to_attribute(params.builtin),
+            std::string(params.attribute_name));
+}
+INSTANTIATE_TEST_SUITE_P(
+    HlslGeneratorImplTest,
+    HlslBuiltinConversionTest,
+    testing::Values(
+        HlslBuiltinData{ast::Builtin::kPosition, "SV_Position"},
+        HlslBuiltinData{ast::Builtin::kVertexIdx, "SV_VertexID"},
+        HlslBuiltinData{ast::Builtin::kInstanceIdx, "SV_InstanceID"},
+        HlslBuiltinData{ast::Builtin::kFrontFacing, "SV_IsFrontFacing"},
+        HlslBuiltinData{ast::Builtin::kFragCoord, "SV_Position"},
+        HlslBuiltinData{ast::Builtin::kFragDepth, "SV_Depth"},
+        HlslBuiltinData{ast::Builtin::kWorkgroupSize, ""},
+        HlslBuiltinData{ast::Builtin::kLocalInvocationId, "SV_GroupThreadID"},
+        HlslBuiltinData{ast::Builtin::kLocalInvocationIdx, "SV_GroupIndex"},
+        HlslBuiltinData{ast::Builtin::kGlobalInvocationId,
+                        "SV_DispatchThreadID"}));
+
 }  // namespace
 }  // namespace hlsl
 }  // namespace writer