[hlsl-writer] Emit numthreads for compute shaders.

This CL adds the numthreads annotation when emitting compute shaders.

Bug: tint:7
Change-Id: Ie0f47adfca0a0684f701f280958163b3da0019b4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27480
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 4ed643d..bf66a22 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1354,6 +1354,13 @@
     return false;
   }
 
+  if (ep->stage() == ast::PipelineStage::kCompute) {
+    // TODO(dsinclair): When we have a way to set the thread group size this
+    // should be updated.
+    out << "[numthreads(1, 1, 1)]" << std::endl;
+    make_indent(out);
+  }
+
   auto outdata = ep_name_to_out_data_.find(current_ep_name_);
   bool has_outdata = outdata != ep_name_to_out_data_.end();
   if (has_outdata) {
diff --git a/src/writer/hlsl/generator_impl_entry_point_test.cc b/src/writer/hlsl/generator_impl_entry_point_test.cc
index 2177806..189c2ab 100644
--- a/src/writer/hlsl/generator_impl_entry_point_test.cc
+++ b/src/writer/hlsl/generator_impl_entry_point_test.cc
@@ -19,6 +19,7 @@
 #include "src/ast/location_decoration.h"
 #include "src/ast/member_accessor_expression.h"
 #include "src/ast/module.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/vector_type.h"
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 51b6830..d79394b 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -1139,10 +1139,10 @@
 TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_WithName) {
   ast::type::VoidType void_type;
 
-  auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{},
+  auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{},
                                               &void_type);
-  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
-                                              "my_main", "comp_main");
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                              "my_main", "frag_main");
 
   mod()->AddFunction(std::move(func));
   mod()->AddEntryPoint(std::move(ep));
@@ -1158,10 +1158,10 @@
        Emit_Function_EntryPoint_WithNameCollision) {
   ast::type::VoidType void_type;
 
-  auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{},
+  auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{},
                                               &void_type);
-  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
-                                              "GeometryShader", "comp_main");
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                              "GeometryShader", "frag_main");
 
   mod()->AddFunction(std::move(func));
   mod()->AddEntryPoint(std::move(ep));
@@ -1173,6 +1173,33 @@
 )");
 }
 
+TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_Compute) {
+  ast::type::VoidType void_type;
+
+  ast::VariableList params;
+  auto func = std::make_unique<ast::Function>("comp_main", std::move(params),
+                                              &void_type);
+
+  auto body = std::make_unique<ast::BlockStatement>();
+  body->append(std::make_unique<ast::ReturnStatement>());
+  func->set_body(std::move(body));
+
+  mod()->AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
+                                              "main", "comp_main");
+  mod()->AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td().Determine()) << td().error();
+  ASSERT_TRUE(gen().Generate(out())) << gen().error();
+  EXPECT_EQ(result(), R"([numthreads(1, 1, 1)]
+void main() {
+  return;
+}
+
+)");
+}
+
 TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
   ast::type::F32Type f32;
   ast::type::ArrayType ary(&f32, 5);