Add Workgroup size information to EntryPoint struct
BUG=tint:257
Change-Id: Iaf03bfaeb622b7315d65e46eccfe90244bced339
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/29420
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/inspector.cc b/src/inspector.cc
index 804ae18..bdcb4df 100644
--- a/src/inspector.cc
+++ b/src/inspector.cc
@@ -27,7 +27,9 @@
std::vector<EntryPoint> result;
for (const auto& func : module_.functions()) {
if (func->IsEntryPoint()) {
- result.push_back({func->name(), func->pipeline_stage()});
+ uint32_t x, y, z;
+ std::tie(x, y, z) = func->workgroup_size();
+ result.push_back({func->name(), func->pipeline_stage(), x, y, z});
}
}
diff --git a/src/inspector.h b/src/inspector.h
index 86ef88e..f20bf42 100644
--- a/src/inspector.h
+++ b/src/inspector.h
@@ -25,11 +25,22 @@
namespace tint {
namespace inspector {
+/// Container of reflection data for an entry point in the shader.
struct EntryPoint {
/// The entry point name
std::string name;
/// The entry point stage
ast::PipelineStage stage = ast::PipelineStage::kNone;
+ /// Elements of the workgroup size tuple
+ uint32_t workgroup_size_x;
+ uint32_t workgroup_size_y;
+ uint32_t workgroup_size_z;
+
+ /// @returns the size of the workgroup in {x,y,z} format
+ std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() {
+ return std::tuple<uint32_t, uint32_t, uint32_t>(
+ workgroup_size_x, workgroup_size_y, workgroup_size_z);
+ }
};
/// Extracts information from a module
diff --git a/src/inspector_test.cc b/src/inspector_test.cc
index def6952..7bd3b4f 100644
--- a/src/inspector_test.cc
+++ b/src/inspector_test.cc
@@ -19,6 +19,7 @@
#include "src/ast/pipeline_stage.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/type/void_type.h"
+#include "src/ast/workgroup_decoration.h"
#include "src/context.h"
namespace tint {
@@ -38,9 +39,15 @@
if (stage != ast::PipelineStage::kNone) {
func->add_decoration(std::make_unique<ast::StageDecoration>(stage));
}
+ last_function_ = func.get();
mod()->AddFunction(std::move(func));
}
+ void AddWorkGroupSizeToLastFunction(uint32_t x, uint32_t y, uint32_t z) {
+ last_function_->add_decoration(
+ std::make_unique<ast::WorkgroupDecoration>(x, y, z));
+ }
+
ast::Module* mod() { return mod_.get(); }
Inspector* inspector() { return inspector_.get(); }
@@ -48,6 +55,7 @@
Context ctx_;
std::unique_ptr<ast::Module> mod_;
std::unique_ptr<Inspector> inspector_;
+ ast::Function* last_function_;
};
class InspectorTest : public InspectorHelper, public testing::Test {};
@@ -110,6 +118,34 @@
EXPECT_EQ(ast::PipelineStage::kCompute, result[1].stage);
}
+TEST_F(InspectorGetEntryPointTest, DefaultWorkgroupSize) {
+ AddFunction("foo", ast::PipelineStage::kVertex);
+
+ auto result = inspector()->GetEntryPoints();
+ ASSERT_FALSE(inspector()->has_error());
+
+ ASSERT_EQ(1u, result.size());
+ uint32_t x, y, z;
+ std::tie(x, y, z) = result[0].workgroup_size();
+ EXPECT_EQ(1u, x);
+ EXPECT_EQ(1u, y);
+ EXPECT_EQ(1u, z);
+}
+
+TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
+ AddFunction("foo", ast::PipelineStage::kCompute);
+ AddWorkGroupSizeToLastFunction(8u, 2u, 1u);
+ auto result = inspector()->GetEntryPoints();
+ ASSERT_FALSE(inspector()->has_error());
+
+ ASSERT_EQ(1u, result.size());
+ uint32_t x, y, z;
+ std::tie(x, y, z) = result[0].workgroup_size();
+ EXPECT_EQ(8u, x);
+ EXPECT_EQ(2u, y);
+ EXPECT_EQ(1u, z);
+}
+
} // namespace
} // namespace inspector
} // namespace tint