[msl-writer] Handle uniform buffers.

This CL adds support for handling uniform data. Currently the uniform is
added to a buffer where the number is the binding value. This will need
to be updated to accept the correct mapping from the embedder.

Bug: tint:8
Change-Id: Icccccbe599a9555defa6136e384745f4093df020
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25104
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 3ca975e..44d5603 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -71,6 +71,34 @@
   return ret;
 }
 
+const std::vector<std::pair<Variable*, Function::BindingInfo>>
+Function::referenced_uniform_variables() const {
+  std::vector<std::pair<Variable*, Function::BindingInfo>> ret;
+
+  for (auto* var : referenced_module_variables()) {
+    if (!var->IsDecorated() ||
+        var->storage_class() != ast::StorageClass::kUniform) {
+      continue;
+    }
+
+    BindingDecoration* binding = nullptr;
+    SetDecoration* set = nullptr;
+    for (const auto& deco : var->AsDecorated()->decorations()) {
+      if (deco->IsBinding()) {
+        binding = deco->AsBinding();
+      } else if (deco->IsSet()) {
+        set = deco->AsSet();
+      }
+    }
+    if (binding == nullptr || set == nullptr) {
+      continue;
+    }
+
+    ret.push_back({var, BindingInfo{binding, set}});
+  }
+  return ret;
+}
+
 const std::vector<std::pair<Variable*, BuiltinDecoration*>>
 Function::referenced_builtin_variables() const {
   std::vector<std::pair<Variable*, BuiltinDecoration*>> ret;
diff --git a/src/ast/function.h b/src/ast/function.h
index a3722e1..4a8c33f 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -21,10 +21,12 @@
 #include <utility>
 #include <vector>
 
+#include "src/ast/binding_decoration.h"
 #include "src/ast/builtin_decoration.h"
 #include "src/ast/expression.h"
 #include "src/ast/location_decoration.h"
 #include "src/ast/node.h"
+#include "src/ast/set_decoration.h"
 #include "src/ast/statement.h"
 #include "src/ast/type/type.h"
 #include "src/ast/variable.h"
@@ -35,6 +37,14 @@
 /// A Function statement.
 class Function : public Node {
  public:
+  /// Information about a binding
+  struct BindingInfo {
+    /// The binding decoration
+    BindingDecoration* binding = nullptr;
+    /// The set decoration
+    SetDecoration* set = nullptr;
+  };
+
   /// Create a new empty function statement
   Function();
   /// Create a function
@@ -86,6 +96,11 @@
   /// @returns the <variable, decoration> pair.
   const std::vector<std::pair<Variable*, BuiltinDecoration*>>
   referenced_builtin_variables() const;
+  /// Retrieves any referenced uniform variables. Note, the uniform must be
+  /// decorated with both binding and set decorations.
+  /// @returns the referenced uniforms
+  const std::vector<std::pair<Variable*, Function::BindingInfo>>
+  referenced_uniform_variables() const;
 
   /// Adds an ancestor entry point
   /// @param ep the entry point ancestor
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 6f8de07..38e4e60 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2915,7 +2915,7 @@
         type_mgr_->FindPointerToType(pointee_type_id, storage_class);
     auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id);
     assert(ast_pointer_type);
-    assert(ast_pointer_type->IsPointer);
+    assert(ast_pointer_type->IsPointer());
     current_expr.reset(TypedExpression(ast_pointer_type, std::move(next_expr)));
   }
   return current_expr;
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 46ce013..dbfff5c 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -448,6 +448,7 @@
       error_ = "Unable to find function: " + name;
       return false;
     }
+
     for (const auto& data : func->referenced_builtin_variables()) {
       auto* var = data.first;
       if (var->storage_class() != ast::StorageClass::kInput) {
@@ -460,6 +461,15 @@
       out_ << var->name();
     }
 
+    for (const auto& data : func->referenced_uniform_variables()) {
+      auto* var = data.first;
+      if (!first) {
+        out_ << ", ";
+      }
+      first = false;
+      out_ << var->name();
+    }
+
     const auto& params = expr->params();
     for (const auto& param : params) {
       if (!first) {
@@ -814,6 +824,25 @@
   return;
 }
 
+bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) {
+  for (auto data : func->referenced_location_variables()) {
+    auto var = data.first;
+    if (var->storage_class() == ast::StorageClass::kInput ||
+        var->storage_class() == ast::StorageClass::kOutput) {
+      return true;
+    }
+  }
+
+  for (auto data : func->referenced_builtin_variables()) {
+    auto var = data.first;
+    if (var->storage_class() == ast::StorageClass::kOutput) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
 bool GeneratorImpl::EmitFunction(ast::Function* func) {
   make_indent();
 
@@ -825,9 +854,8 @@
   // TODO(dsinclair): This could be smarter. If the input/outputs for multiple
   // entry points are the same we could generate a single struct and then have
   // this determine it's the same struct and just emit once.
-  bool emit_duplicate_functions =
-      func->ancestor_entry_points().size() > 0 &&
-      func->referenced_module_variables().size() > 0;
+  bool emit_duplicate_functions = func->ancestor_entry_points().size() > 0 &&
+                                  has_referenced_var_needing_struct(func);
 
   if (emit_duplicate_functions) {
     for (const auto& ep_name : func->ancestor_entry_points()) {
@@ -857,7 +885,6 @@
   }
 
   out_ << " ";
-
   if (emit_duplicate_functions) {
     name = generate_name(name + "_" + ep_name);
     ep_func_name_remapped_[ep_name + "_" + func->name()] = name;
@@ -908,6 +935,21 @@
     out_ << "& " << var->name();
   }
 
+  for (const auto& data : func->referenced_uniform_variables()) {
+    auto* var = data.first;
+    if (!first) {
+      out_ << ", ";
+    }
+    first = false;
+
+    out_ << "constant ";
+    // TODO(dsinclair): Can arrays be uniform? If so, fix this ...
+    if (!EmitType(var->type(), "")) {
+      return false;
+    }
+    out_ << "& " << var->name();
+  }
+
   // TODO(dsinclair): Binding/Set inputs
 
   for (const auto& v : func->params()) {
@@ -1034,6 +1076,28 @@
     out_ << " " << var->name() << " [[" << attr << "]]";
   }
 
+  for (auto data : func->referenced_uniform_variables()) {
+    if (!first) {
+      out_ << ", ";
+    }
+    first = false;
+
+    auto* var = data.first;
+    // TODO(dsinclair): We're using the binding to make up the buffer number but
+    // we should instead be using a provided mapping that uses both buffer and
+    // set. https://bugs.chromium.org/p/tint/issues/detail?id=104
+    auto* binding = data.second.binding;
+    // auto* set = data.second.set;
+
+    out_ << "constant ";
+    // TODO(dsinclair): Can you have a uniform array? If so, this needs to be
+    // updated to handle arrays property.
+    if (!EmitType(var->type(), "")) {
+      return false;
+    }
+    out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]";
+  }
+
   // TODO(dsinclair): Binding/Set inputs
 
   out_ << ") {" << std::endl;
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index df90697..094ba80 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -208,6 +208,11 @@
   /// @param mod the module to set.
   void set_module_for_testing(ast::Module* mod);
 
+  /// Determines if any used module variable requires an input or output struct.
+  /// @param func the function to check
+  /// @returns true if an input or output struct is required.
+  bool has_referenced_var_needing_struct(ast::Function* func);
+
   /// Generates a name for the prefix
   /// @param prefix the prefix of the name to generate
   /// @returns the name
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc
index 5e25073..0558247 100644
--- a/src/writer/msl/generator_impl_function_test.cc
+++ b/src/writer/msl/generator_impl_function_test.cc
@@ -15,6 +15,7 @@
 #include "gtest/gtest.h"
 #include "src/ast/assignment_statement.h"
 #include "src/ast/binary_expression.h"
+#include "src/ast/binding_decoration.h"
 #include "src/ast/call_expression.h"
 #include "src/ast/decorated_variable.h"
 #include "src/ast/float_literal.h"
@@ -26,6 +27,7 @@
 #include "src/ast/module.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/set_decoration.h"
 #include "src/ast/sint_literal.h"
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/f32_type.h"
@@ -286,6 +288,62 @@
 )");
 }
 
+TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_With_Uniform) {
+  ast::type::VoidType void_type;
+  ast::type::F32Type f32;
+  ast::type::VectorType vec4(&f32, 4);
+
+  auto coord_var =
+      std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+          "coord", ast::StorageClass::kUniform, &vec4));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::BindingDecoration>(0));
+  decos.push_back(std::make_unique<ast::SetDecoration>(1));
+  coord_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(coord_var.get());
+
+  mod.AddGlobalVariable(std::move(coord_var));
+
+  ast::VariableList params;
+  auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
+                                              &void_type);
+
+  auto var =
+      std::make_unique<ast::Variable>("v", ast::StorageClass::kFunction, &f32);
+  var->set_constructor(std::make_unique<ast::MemberAccessorExpression>(
+      std::make_unique<ast::IdentifierExpression>("coord"),
+      std::make_unique<ast::IdentifierExpression>("x")));
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+  body.push_back(std::make_unique<ast::ReturnStatement>());
+  func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
+                                              "frag_main");
+  mod.AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  ASSERT_TRUE(g.Generate(mod)) << g.error();
+  EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+fragment void frag_main(constant float4& coord [[buffer(0)]]) {
+  float v = coord.x;
+  return;
+}
+
+)");
+}
+
 TEST_F(MslGeneratorImplTest,
        Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) {
   ast::type::VoidType void_type;
@@ -481,6 +539,83 @@
 )");
 }
 
+TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoint_With_Uniform) {
+  ast::type::VoidType void_type;
+  ast::type::F32Type f32;
+  ast::type::VectorType vec4(&f32, 4);
+
+  auto coord_var =
+      std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+          "coord", ast::StorageClass::kUniform, &vec4));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::BindingDecoration>(0));
+  decos.push_back(std::make_unique<ast::SetDecoration>(1));
+  coord_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(coord_var.get());
+
+  mod.AddGlobalVariable(std::move(coord_var));
+
+  ast::VariableList params;
+  params.push_back(std::make_unique<ast::Variable>(
+      "param", ast::StorageClass::kFunction, &f32));
+  auto sub_func =
+      std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::ReturnStatement>(
+      std::make_unique<ast::MemberAccessorExpression>(
+          std::make_unique<ast::IdentifierExpression>("coord"),
+          std::make_unique<ast::IdentifierExpression>("x"))));
+  sub_func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(sub_func));
+
+  auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
+                                              &void_type);
+
+  ast::ExpressionList expr;
+  expr.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+
+  auto var =
+      std::make_unique<ast::Variable>("v", ast::StorageClass::kFunction, &f32);
+  var->set_constructor(std::make_unique<ast::CallExpression>(
+      std::make_unique<ast::IdentifierExpression>("sub_func"),
+      std::move(expr)));
+
+  body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+  body.push_back(std::make_unique<ast::ReturnStatement>());
+  func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
+                                              "frag_main");
+  mod.AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  ASSERT_TRUE(g.Generate(mod)) << g.error();
+  EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+float sub_func(constant float4& coord, float param) {
+  return coord.x;
+}
+
+fragment void frag_main(constant float4& coord [[buffer(0)]]) {
+  float v = sub_func(coord, 1.00000000f);
+  return;
+}
+
+)");
+}
+
 TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) {
   ast::type::VoidType void_type;
   ast::type::F32Type f32;