[ast][spirv-writer][hlsl-writer][wgsl-writer] Add workgroup_size decoration

This CL adds the workgroup_size decoration to functions and emits as
needed from the various backends.

Change-Id: Ifffde239e68047f6419c6980eca70c4efa9822c0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28662
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 9833709..a37e2a0 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -369,6 +369,8 @@
     "src/ast/variable_decl_statement.h",
     "src/ast/variable_decoration.cc",
     "src/ast/variable_decoration.h",
+    "src/ast/workgroup_decoration.cc",
+    "src/ast/workgroup_decoration.h",
     "src/context.cc",
     "src/context.h",
     "src/reader/reader.cc",
@@ -744,6 +746,7 @@
     "src/ast/unary_op_expression_test.cc",
     "src/ast/variable_decl_statement_test.cc",
     "src/ast/variable_test.cc",
+    "src/ast/workgroup_decoration_test.cc",
     "src/scope_stack_test.cc",
     "src/type_determiner_test.cc",
     "src/type_manager_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index f821429..04fbea4 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -190,6 +190,8 @@
   ast/variable_decoration.h
   ast/variable_decl_statement.cc
   ast/variable_decl_statement.h
+  ast/workgroup_decoration.cc
+  ast/workgroup_decoration.h
   context.h
   context.cc
   reader/reader.cc
@@ -354,6 +356,7 @@
   ast/unary_op_expression_test.cc
   ast/variable_decl_statement_test.cc
   ast/variable_test.cc
+  ast/workgroup_decoration_test.cc
   scope_stack_test.cc
   type_determiner_test.cc
   type_manager_test.cc
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 0e52bf9..8e8d853 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -17,6 +17,7 @@
 #include <sstream>
 
 #include "src/ast/decorated_variable.h"
+#include "src/ast/workgroup_decoration.h"
 
 namespace tint {
 namespace ast {
@@ -46,6 +47,15 @@
 
 Function::~Function() = default;
 
+std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
+  for (const auto& deco : decorations_) {
+    if (deco->IsWorkgroup()) {
+      return deco->AsWorkgroup()->values();
+    }
+  }
+  return {1, 1, 1};
+}
+
 void Function::add_referenced_module_variable(Variable* var) {
   for (const auto* v : referenced_module_vars_) {
     if (v->name() == var->name()) {
diff --git a/src/ast/function.h b/src/ast/function.h
index b292464..d64c023 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -90,6 +90,10 @@
   /// @returns the decorations attached to this function
   const FunctionDecorationList& decorations() const { return decorations_; }
 
+  /// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
+  /// return if no workgroup size was set.
+  std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
+
   /// Adds the given variable to the list of referenced module variables if it
   /// is not already included.
   /// @param var the module variable to add
diff --git a/src/ast/function_decoration.cc b/src/ast/function_decoration.cc
index a447297..c246b97 100644
--- a/src/ast/function_decoration.cc
+++ b/src/ast/function_decoration.cc
@@ -14,6 +14,10 @@
 
 #include "src/ast/function_decoration.h"
 
+#include <assert.h>
+
+#include "src/ast/workgroup_decoration.h"
+
 namespace tint {
 namespace ast {
 
@@ -21,5 +25,14 @@
 
 FunctionDecoration::~FunctionDecoration() = default;
 
+bool FunctionDecoration::IsWorkgroup() const {
+  return false;
+}
+
+const WorkgroupDecoration* FunctionDecoration::AsWorkgroup() const {
+  assert(IsWorkgroup());
+  return static_cast<const WorkgroupDecoration*>(this);
+}
+
 }  // namespace ast
 }  // namespace tint
diff --git a/src/ast/function_decoration.h b/src/ast/function_decoration.h
index 4b75f30..461a037 100644
--- a/src/ast/function_decoration.h
+++ b/src/ast/function_decoration.h
@@ -22,11 +22,19 @@
 namespace tint {
 namespace ast {
 
+class WorkgroupDecoration;
+
 /// A decoration attached to a function
 class FunctionDecoration {
  public:
   virtual ~FunctionDecoration();
 
+  /// @returns true if this is a workgroup decoration
+  virtual bool IsWorkgroup() const;
+
+  /// @returns the decoration as a workgroup decoration
+  const WorkgroupDecoration* AsWorkgroup() const;
+
   /// Outputs the function decoration to the given stream
   /// @param out the stream to output too
   virtual void to_str(std::ostream& out) const = 0;
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index a75a619..20d37ed 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -24,7 +24,7 @@
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/void_type.h"
 #include "src/ast/variable.h"
-// #include "src/ast/workgroup_decoration.h"
+#include "src/ast/workgroup_decoration.h"
 
 namespace tint {
 namespace ast {
@@ -298,27 +298,27 @@
 )");
 }
 
-// TEST_F(FunctionTest, ToStr_WithDecoration) {
-//   type::VoidType void_type;
-//   type::I32Type i32;
+TEST_F(FunctionTest, ToStr_WithDecoration) {
+  type::VoidType void_type;
+  type::I32Type i32;
 
-//   auto block = std::make_unique<ast::BlockStatement>();
-//   block->append(std::make_unique<DiscardStatement>());
+  auto block = std::make_unique<ast::BlockStatement>();
+  block->append(std::make_unique<DiscardStatement>());
 
-//   Function f("func", {}, &void_type);
-//   f.set_body(std::move(block));
-//   f.add_decoration(std::make_unique<WorkgroupDecoration>(2, 4, 6));
+  Function f("func", {}, &void_type);
+  f.set_body(std::move(block));
+  f.add_decoration(std::make_unique<WorkgroupDecoration>(2, 4, 6));
 
-//   std::ostringstream out;
-//   f.to_str(out, 2);
-//   EXPECT_EQ(out.str(), R"(  Function func -> __void
-//   workgroup_size 2 4 6
-//   ()
-//   {
-//     Discard{}
-//   }
-// )");
-// }
+  std::ostringstream out;
+  f.to_str(out, 2);
+  EXPECT_EQ(out.str(), R"(  Function func -> __void
+  WorkgroupDecoration{2 4 6}
+  ()
+  {
+    Discard{}
+  }
+)");
+}
 
 TEST_F(FunctionTest, ToStr_WithParams) {
   type::VoidType void_type;
@@ -396,6 +396,33 @@
 
   EXPECT_EQ(f.get_last_statement(), nullptr);
 }
+
+TEST_F(FunctionTest, WorkgroupSize_NoneSet) {
+  type::VoidType void_type;
+  Function f("f", {}, &void_type);
+  uint32_t x = 0;
+  uint32_t y = 0;
+  uint32_t z = 0;
+  std::tie(x, y, z) = f.workgroup_size();
+  EXPECT_EQ(x, 1u);
+  EXPECT_EQ(y, 1u);
+  EXPECT_EQ(z, 1u);
+}
+
+TEST_F(FunctionTest, WorkgroupSize) {
+  type::VoidType void_type;
+  Function f("f", {}, &void_type);
+  f.add_decoration(std::make_unique<WorkgroupDecoration>(2u, 4u, 6u));
+
+  uint32_t x = 0;
+  uint32_t y = 0;
+  uint32_t z = 0;
+  std::tie(x, y, z) = f.workgroup_size();
+  EXPECT_EQ(x, 2u);
+  EXPECT_EQ(y, 4u);
+  EXPECT_EQ(z, 6u);
+}
+
 }  // namespace
 }  // namespace ast
 }  // namespace tint
diff --git a/src/ast/workgroup_decoration.cc b/src/ast/workgroup_decoration.cc
new file mode 100644
index 0000000..fd44db4
--- /dev/null
+++ b/src/ast/workgroup_decoration.cc
@@ -0,0 +1,40 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/ast/workgroup_decoration.h"
+
+namespace tint {
+namespace ast {
+
+WorkgroupDecoration::WorkgroupDecoration(uint32_t x) : x_(x) {}
+
+WorkgroupDecoration::WorkgroupDecoration(uint32_t x, uint32_t y)
+    : x_(x), y_(y) {}
+
+WorkgroupDecoration::WorkgroupDecoration(uint32_t x, uint32_t y, uint32_t z)
+    : x_(x), y_(y), z_(z) {}
+
+WorkgroupDecoration::~WorkgroupDecoration() = default;
+
+bool WorkgroupDecoration::IsWorkgroup() const {
+  return true;
+}
+
+void WorkgroupDecoration::to_str(std::ostream& out) const {
+  out << "WorkgroupDecoration{" << x_ << " " << y_ << " " << z_ << "}"
+      << std::endl;
+}
+
+}  // namespace ast
+}  // namespace tint
diff --git a/src/ast/workgroup_decoration.h b/src/ast/workgroup_decoration.h
new file mode 100644
index 0000000..04678d6
--- /dev/null
+++ b/src/ast/workgroup_decoration.h
@@ -0,0 +1,65 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_AST_WORKGROUP_DECORATION_H_
+#define SRC_AST_WORKGROUP_DECORATION_H_
+
+#include <stddef.h>
+
+#include <tuple>
+
+#include "src/ast/function_decoration.h"
+
+namespace tint {
+namespace ast {
+
+/// A workgroup decoration
+class WorkgroupDecoration : public FunctionDecoration {
+ public:
+  /// constructor
+  /// @param x the workgroup x dimension size
+  explicit WorkgroupDecoration(uint32_t x);
+  /// constructor
+  /// @param x the workgroup x dimension size
+  /// @param y the workgroup x dimension size
+  WorkgroupDecoration(uint32_t x, uint32_t y);
+  /// constructor
+  /// @param x the workgroup x dimension size
+  /// @param y the workgroup x dimension size
+  /// @param z the workgroup x dimension size
+  WorkgroupDecoration(uint32_t x, uint32_t y, uint32_t z);
+  ~WorkgroupDecoration() override;
+
+  /// @returns true if this is a workgroup decoration
+  bool IsWorkgroup() const override;
+
+  /// @returns the workgroup dimensions
+  std::tuple<uint32_t, uint32_t, uint32_t> values() const {
+    return {x_, y_, z_};
+  }
+
+  /// Outputs the decoration to the given stream
+  /// @param out the stream to output too
+  void to_str(std::ostream& out) const override;
+
+ private:
+  uint32_t x_ = 1;
+  uint32_t y_ = 1;
+  uint32_t z_ = 1;
+};
+
+}  // namespace ast
+}  // namespace tint
+
+#endif  // SRC_AST_WORKGROUP_DECORATION_H_
diff --git a/src/ast/workgroup_decoration_test.cc b/src/ast/workgroup_decoration_test.cc
new file mode 100644
index 0000000..750d351
--- /dev/null
+++ b/src/ast/workgroup_decoration_test.cc
@@ -0,0 +1,74 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/ast/workgroup_decoration.h"
+
+#include <sstream>
+
+#include "gtest/gtest.h"
+
+namespace tint {
+namespace ast {
+namespace {
+
+using WorkgroupDecorationTest = testing::Test;
+
+TEST_F(WorkgroupDecorationTest, Creation_1param) {
+  WorkgroupDecoration d{2};
+  uint32_t x = 0;
+  uint32_t y = 0;
+  uint32_t z = 0;
+  std::tie(x, y, z) = d.values();
+  EXPECT_EQ(x, 2u);
+  EXPECT_EQ(y, 1u);
+  EXPECT_EQ(z, 1u);
+}
+TEST_F(WorkgroupDecorationTest, Creation_2param) {
+  WorkgroupDecoration d{2, 4};
+  uint32_t x = 0;
+  uint32_t y = 0;
+  uint32_t z = 0;
+  std::tie(x, y, z) = d.values();
+  EXPECT_EQ(x, 2u);
+  EXPECT_EQ(y, 4u);
+  EXPECT_EQ(z, 1u);
+}
+
+TEST_F(WorkgroupDecorationTest, Creation_3param) {
+  WorkgroupDecoration d{2, 4, 6};
+  uint32_t x = 0;
+  uint32_t y = 0;
+  uint32_t z = 0;
+  std::tie(x, y, z) = d.values();
+  EXPECT_EQ(x, 2u);
+  EXPECT_EQ(y, 4u);
+  EXPECT_EQ(z, 6u);
+}
+
+TEST_F(WorkgroupDecorationTest, Is) {
+  WorkgroupDecoration d{2, 4, 6};
+  EXPECT_TRUE(d.IsWorkgroup());
+}
+
+TEST_F(WorkgroupDecorationTest, ToStr) {
+  WorkgroupDecoration d{2, 4, 6};
+  std::ostringstream out;
+  d.to_str(out);
+  EXPECT_EQ(out.str(), R"(WorkgroupDecoration{2 4 6}
+)");
+}
+
+}  // namespace
+}  // namespace ast
+}  // namespace tint
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 43270f9..d34e572 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1428,9 +1428,12 @@
   }
 
   if (ep->stage() == ast::PipelineStage::kCompute) {
-    // TODO(dsinclair): When we have a way to set the thread group size this
-    // should be updated.
-    out << "[numthreads(1, 1, 1)]" << std::endl;
+    uint32_t x = 0;
+    uint32_t y = 0;
+    uint32_t z = 0;
+    std::tie(x, y, z) = func->workgroup_size();
+    out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y)
+        << ", " << std::to_string(z) << ")]" << std::endl;
     make_indent(out);
   }
 
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 80f5b82..2056427 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -39,6 +39,7 @@
 #include "src/ast/type/void_type.h"
 #include "src/ast/variable.h"
 #include "src/ast/variable_decl_statement.h"
+#include "src/ast/workgroup_decoration.h"
 #include "src/context.h"
 #include "src/type_determiner.h"
 #include "src/writer/hlsl/test_helper.h"
@@ -1223,6 +1224,35 @@
 )");
 }
 
+TEST_F(HlslGeneratorImplTest_Function,
+       Emit_Function_EntryPoint_Compute_WithWorkgroup) {
+  ast::type::VoidType void_type;
+
+  ast::VariableList params;
+  auto func = std::make_unique<ast::Function>("comp_main", std::move(params),
+                                              &void_type);
+  func->add_decoration(std::make_unique<ast::WorkgroupDecoration>(2u, 4u, 6u));
+
+  auto body = std::make_unique<ast::BlockStatement>();
+  body->append(std::make_unique<ast::ReturnStatement>());
+  func->set_body(std::move(body));
+
+  mod()->AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
+                                              "main", "comp_main");
+  mod()->AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td().Determine()) << td().error();
+  ASSERT_TRUE(gen().Generate(out())) << gen().error();
+  EXPECT_EQ(result(), R"([numthreads(2, 4, 6)]
+void main() {
+  return;
+}
+
+)");
+}
+
 TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
   ast::type::F32Type f32;
   ast::type::ArrayType ary(&f32, 5);
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 9754fb2..879ded3 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -359,10 +359,15 @@
         spv::Op::OpExecutionMode,
         {Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)});
   } else if (ep->stage() == ast::PipelineStage::kCompute) {
-    // TODO(dsinclair): Support LocalSize other then (1, 1, 1)
+    auto* func = func_name_to_func_[ep->function_name()];
+
+    uint32_t x = 0;
+    uint32_t y = 0;
+    uint32_t z = 0;
+    std::tie(x, y, z) = func->workgroup_size();
     push_preamble(spv::Op::OpExecutionMode,
                   {Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize),
-                   Operand::Int(1), Operand::Int(1), Operand::Int(1)});
+                   Operand::Int(x), Operand::Int(y), Operand::Int(z)});
   }
 
   return true;
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index 04176c1..85af7cf 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -25,6 +25,7 @@
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/void_type.h"
 #include "src/ast/variable.h"
+#include "src/ast/workgroup_decoration.h"
 #include "src/context.h"
 #include "src/type_determiner.h"
 #include "src/writer/spirv/builder.h"
@@ -264,6 +265,23 @@
 )");
 }
 
+TEST_F(BuilderTest, ExecutionModel_Compute_LocalSize_WithWorkgroup) {
+  ast::type::VoidType void_type;
+
+  ast::Function func("main", {}, &void_type);
+  func.add_decoration(std::make_unique<ast::WorkgroupDecoration>(2u, 4u, 6u));
+  ast::EntryPoint ep(ast::PipelineStage::kCompute, "main", "main");
+
+  ast::Module mod;
+  Builder b(&mod);
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+  ASSERT_TRUE(b.GenerateExecutionModes(&ep));
+
+  EXPECT_EQ(DumpInstructions(b.preamble()),
+            R"(OpExecutionMode %3 LocalSize 2 4 6
+)");
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 11d3690..08afc35 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -63,6 +63,7 @@
 #include "src/ast/uint_literal.h"
 #include "src/ast/unary_op_expression.h"
 #include "src/ast/variable_decl_statement.h"
+#include "src/ast/workgroup_decoration.h"
 
 namespace tint {
 namespace writer {
@@ -422,8 +423,21 @@
 }
 
 bool GeneratorImpl::EmitFunction(ast::Function* func) {
-  make_indent();
+  for (auto& deco : func->decorations()) {
+    make_indent();
+    out_ << "[[";
+    if (deco->IsWorkgroup()) {
+      uint32_t x = 0;
+      uint32_t y = 0;
+      uint32_t z = 0;
+      std::tie(x, y, z) = deco->AsWorkgroup()->values();
+      out_ << "workgroup_size(" << std::to_string(x) << ", "
+           << std::to_string(y) << ", " << std::to_string(z) << ")";
+    }
+    out_ << "]]" << std::endl;
+  }
 
+  make_indent();
   out_ << "fn " << func->name() << "(";
 
   bool first = true;
diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc
index 0240a62..25bac24 100644
--- a/src/writer/wgsl/generator_impl_function_test.cc
+++ b/src/writer/wgsl/generator_impl_function_test.cc
@@ -20,6 +20,7 @@
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/void_type.h"
 #include "src/ast/variable.h"
+#include "src/ast/workgroup_decoration.h"
 #include "src/writer/wgsl/generator_impl.h"
 
 namespace tint {
@@ -77,6 +78,28 @@
 )");
 }
 
+TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecorations) {
+  auto body = std::make_unique<ast::BlockStatement>();
+  body->append(std::make_unique<ast::DiscardStatement>());
+  body->append(std::make_unique<ast::ReturnStatement>());
+
+  ast::type::VoidType void_type;
+  ast::Function func("my_func", {}, &void_type);
+  func.add_decoration(std::make_unique<ast::WorkgroupDecoration>(2u, 4u, 6u));
+  func.set_body(std::move(body));
+
+  GeneratorImpl g;
+  g.increment_indent();
+
+  ASSERT_TRUE(g.EmitFunction(&func));
+  EXPECT_EQ(g.result(), R"(  [[workgroup_size(2, 4, 6)]]
+  fn my_func() -> void {
+    discard;
+    return;
+  }
+)");
+}
+
 }  // namespace
 }  // namespace wgsl
 }  // namespace writer