[hlsl-writer] Emit numthreads for compute shaders.
This CL adds the numthreads annotation when emitting compute shaders.
Bug: tint:7
Change-Id: Ie0f47adfca0a0684f701f280958163b3da0019b4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27480
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 4ed643d..bf66a22 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1354,6 +1354,13 @@
return false;
}
+ if (ep->stage() == ast::PipelineStage::kCompute) {
+ // TODO(dsinclair): When we have a way to set the thread group size this
+ // should be updated.
+ out << "[numthreads(1, 1, 1)]" << std::endl;
+ make_indent(out);
+ }
+
auto outdata = ep_name_to_out_data_.find(current_ep_name_);
bool has_outdata = outdata != ep_name_to_out_data_.end();
if (has_outdata) {
diff --git a/src/writer/hlsl/generator_impl_entry_point_test.cc b/src/writer/hlsl/generator_impl_entry_point_test.cc
index 2177806..189c2ab 100644
--- a/src/writer/hlsl/generator_impl_entry_point_test.cc
+++ b/src/writer/hlsl/generator_impl_entry_point_test.cc
@@ -19,6 +19,7 @@
#include "src/ast/location_decoration.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/module.h"
+#include "src/ast/return_statement.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/vector_type.h"
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 51b6830..d79394b 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -1139,10 +1139,10 @@
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_WithName) {
ast::type::VoidType void_type;
- auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{},
+ auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{},
&void_type);
- auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
- "my_main", "comp_main");
+ auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+ "my_main", "frag_main");
mod()->AddFunction(std::move(func));
mod()->AddEntryPoint(std::move(ep));
@@ -1158,10 +1158,10 @@
Emit_Function_EntryPoint_WithNameCollision) {
ast::type::VoidType void_type;
- auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{},
+ auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{},
&void_type);
- auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
- "GeometryShader", "comp_main");
+ auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+ "GeometryShader", "frag_main");
mod()->AddFunction(std::move(func));
mod()->AddEntryPoint(std::move(ep));
@@ -1173,6 +1173,33 @@
)");
}
+TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_Compute) {
+ ast::type::VoidType void_type;
+
+ ast::VariableList params;
+ auto func = std::make_unique<ast::Function>("comp_main", std::move(params),
+ &void_type);
+
+ auto body = std::make_unique<ast::BlockStatement>();
+ body->append(std::make_unique<ast::ReturnStatement>());
+ func->set_body(std::move(body));
+
+ mod()->AddFunction(std::move(func));
+
+ auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
+ "main", "comp_main");
+ mod()->AddEntryPoint(std::move(ep));
+
+ ASSERT_TRUE(td().Determine()) << td().error();
+ ASSERT_TRUE(gen().Generate(out())) << gen().error();
+ EXPECT_EQ(result(), R"([numthreads(1, 1, 1)]
+void main() {
+ return;
+}
+
+)");
+}
+
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
ast::type::F32Type f32;
ast::type::ArrayType ary(&f32, 5);