Revert "[tint] Remove workgroup size and storage reflection"
This reverts commit 1c8567bd2af2b3ac39ba6f59457244bf9894fea3.
Reason for revert: breaks a downstream client that used the workgroup size reflection
Original change's description:
> [tint] Remove workgroup size and storage reflection
>
> This information is now reflected from the IR module during the
> backend codegen, so these fields are unused.
>
> Change-Id: Ia024074419a7dda0ef5bf71fa4fa365fec701bfd
> Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/277075
> Reviewed-by: dan sinclair <dsinclair@chromium.org>
> Commit-Queue: James Price <jrprice@google.com>
# Not skipping CQ checks because original CL landed > 1 day ago.
Change-Id: Id0748dceb359174c4359cac6b30ccee74c287d03
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/278876
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/tint/cmd/common/helper.cc b/src/tint/cmd/common/helper.cc
index ee679ca..f4e229e 100644
--- a/src/tint/cmd/common/helper.cc
+++ b/src/tint/cmd/common/helper.cc
@@ -272,6 +272,12 @@
std::cout << "Entry Point = " << entry_point.name << " ("
<< EntryPointStageToString(entry_point.stage) << ")\n";
+ if (entry_point.workgroup_size) {
+ std::cout << " Workgroup Size (" << entry_point.workgroup_size->x << ", "
+ << entry_point.workgroup_size->y << ", " << entry_point.workgroup_size->z
+ << ")\n";
+ }
+
if (!entry_point.input_variables.empty()) {
std::cout << " Input Variables:\n";
diff --git a/src/tint/cmd/info/main.cc b/src/tint/cmd/info/main.cc
index 3bdc894..2e563f3 100644
--- a/src/tint/cmd/info/main.cc
+++ b/src/tint/cmd/info/main.cc
@@ -132,6 +132,12 @@
<< "\"stage\": \"" << tint::cmd::EntryPointStageToString(entry_point.stage)
<< "\",\n";
+ if (entry_point.workgroup_size) {
+ std::cout << "\"workgroup_size\": [";
+ std::cout << entry_point.workgroup_size->x << ", " << entry_point.workgroup_size->y
+ << ", " << entry_point.workgroup_size->z << "],\n";
+ }
+
std::cout << "\"input_variables\": [";
bool input_first = true;
for (const auto& var : entry_point.input_variables) {
diff --git a/src/tint/lang/wgsl/inspector/entry_point.h b/src/tint/lang/wgsl/inspector/entry_point.h
index 59ba1cc..a17de12 100644
--- a/src/tint/lang/wgsl/inspector/entry_point.h
+++ b/src/tint/lang/wgsl/inspector/entry_point.h
@@ -163,6 +163,12 @@
std::string name;
/// The entry point stage
PipelineStage stage;
+ /// The workgroup size. If PipelineStage is kCompute and this holds no value, then the workgroup
+ /// size is derived from an override-expression. In this situation you first need to run the
+ /// SubstituteOverride transform before using the inspector.
+ std::optional<WorkgroupSize> workgroup_size;
+ /// The total size in bytes of all Workgroup storage-class storage accessed via the entry point.
+ uint32_t workgroup_storage_size = 0;
/// The total size in bytes of all immediate variables accessed by the entry point.
uint32_t immediate_data_size = 0;
/// List of the input variable accessed via this entry point.
diff --git a/src/tint/lang/wgsl/inspector/inspector.cc b/src/tint/lang/wgsl/inspector/inspector.cc
index 31cf03d..654bbf7 100644
--- a/src/tint/lang/wgsl/inspector/inspector.cc
+++ b/src/tint/lang/wgsl/inspector/inspector.cc
@@ -68,6 +68,7 @@
#include "src/tint/lang/wgsl/sem/type_expression.h"
#include "src/tint/lang/wgsl/sem/variable.h"
#include "src/tint/utils/containers/unique_vector.h"
+#include "src/tint/utils/math/math.h"
#include "src/tint/utils/rtti/switch.h"
#include "src/tint/utils/text/string.h"
@@ -310,7 +311,16 @@
switch (func->PipelineStage()) {
case ast::PipelineStage::kCompute: {
entry_point.stage = PipelineStage::kCompute;
+ entry_point.workgroup_storage_size = ComputeWorkgroupStorageSize(func);
+
+ auto wgsize = sem->WorkgroupSize();
+ if (wgsize[0].has_value() && wgsize[1].has_value() && wgsize[2].has_value()) {
+ entry_point.workgroup_size = {wgsize[0].value(), wgsize[1].value(),
+ wgsize[2].value()};
+ }
+
entry_point.uses_subgroup_matrix = UsesSubgroupMatrix(sem);
+
break;
}
case ast::PipelineStage::kFragment: {
@@ -882,6 +892,26 @@
return {interpolation_type, sampling_type};
}
+uint32_t Inspector::ComputeWorkgroupStorageSize(const ast::Function* func) const {
+    // Sum a padded footprint for every var<workgroup> that `func` references, either directly
+    // or through the functions it calls.
+    auto* sem_func = program_.Sem().Get(func);
+    uint32_t bytes = 0;
+    for (const sem::Variable* global : sem_func->TransitivelyReferencedGlobals()) {
+        if (global->AddressSpace() != core::AddressSpace::kWorkgroup) {
+            continue;
+        }
+        auto* store_ty = global->Type()->UnwrapRef();
+        const uint32_t type_size = store_ty->Size();
+        const uint32_t type_align = store_ty->Align();
+        // Round the size up to the type's alignment, then to a multiple of 16 bytes. This roughly
+        // matches GLSL std430 rules, the upper bound for Vulkan layout sizing; D3D and Metal are
+        // less specific, so Vulkan behavior is assumed to be a good-enough approximation.
+        bytes += tint::RoundUp(16u, tint::RoundUp(type_align, type_size));
+    }
+    return bytes;
+}
+
uint32_t Inspector::ComputeImmediateDataSize(const ast::Function* func) const {
uint32_t size = 0;
auto* func_sem = program_.Sem().Get(func);
diff --git a/src/tint/lang/wgsl/inspector/inspector.h b/src/tint/lang/wgsl/inspector/inspector.h
index 3d441c0..8a3a3f9 100644
--- a/src/tint/lang/wgsl/inspector/inspector.h
+++ b/src/tint/lang/wgsl/inspector/inspector.h
@@ -206,6 +206,10 @@
VectorRef<const ast::Attribute*> attributes) const;
/// @param func the root function of the callgraph to consider for the computation.
+ /// @returns the total size in bytes of all Workgroup storage-class storage accessed via func.
+ uint32_t ComputeWorkgroupStorageSize(const ast::Function* func) const;
+
+ /// @param func the root function of the callgraph to consider for the computation.
/// @returns the total size in bytes of all immediate data variables accessed via func.
uint32_t ComputeImmediateDataSize(const ast::Function* func) const;
diff --git a/src/tint/lang/wgsl/inspector/inspector_test.cc b/src/tint/lang/wgsl/inspector/inspector_test.cc
index 4d45a58..957406c 100644
--- a/src/tint/lang/wgsl/inspector/inspector_test.cc
+++ b/src/tint/lang/wgsl/inspector/inspector_test.cc
@@ -168,6 +168,23 @@
EXPECT_EQ(PipelineStage::kFragment, result[1].stage);
}
+TEST_F(InspectorGetEntryPointTest, DefaultWorkgroupSize) {
+    // Only the x dimension is given; y and z must fall back to their default of 1.
+    auto* src = R"(
+@compute @workgroup_size(8i) fn foo() {}
+)";
+    Inspector& inspector = Initialize(src);
+    auto result = inspector.GetEntryPoints();
+    ASSERT_FALSE(inspector.has_error()) << inspector.error();
+
+    ASSERT_EQ(1u, result.size());
+    auto workgroup_size = result[0].workgroup_size;
+    ASSERT_TRUE(workgroup_size.has_value());
+    EXPECT_EQ(8u, workgroup_size->x);
+    EXPECT_EQ(1u, workgroup_size->y);
+    EXPECT_EQ(1u, workgroup_size->z);
+}
+
// Test that immediate_data_size is zero if there are no immediate data.
TEST_F(InspectorGetEntryPointTest, ImmediateDataSizeNone) {
auto* src = R"(
@@ -246,6 +263,134 @@
EXPECT_EQ(4u, result[0].immediate_data_size);
}
+TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
+    // Every dimension is distinct and non-1, so a defaulted or swapped axis is caught.
+    auto* src = R"(
+@compute @workgroup_size(8i, 2i, 4i)
+fn foo() {}
+)";
+    Inspector& inspector = Initialize(src);
+    auto result = inspector.GetEntryPoints();
+    ASSERT_FALSE(inspector.has_error()) << inspector.error();
+
+    ASSERT_EQ(1u, result.size());
+    auto workgroup_size = result[0].workgroup_size;
+    ASSERT_TRUE(workgroup_size.has_value());
+    EXPECT_EQ(8u, workgroup_size->x);
+    EXPECT_EQ(2u, workgroup_size->y);
+    EXPECT_EQ(4u, workgroup_size->z);
+}
+
+TEST_F(InspectorGetEntryPointTest, WorkgroupStorageSizeEmpty) {
+    // An entry point that references no var<workgroup> must report zero bytes.
+    auto* src = R"(
+@compute @workgroup_size(1i)
+fn ep_func() {}
+)";
+    Inspector& inspector = Initialize(src);
+    auto entry_points = inspector.GetEntryPoints();
+    ASSERT_FALSE(inspector.has_error()) << inspector.error();
+    ASSERT_EQ(1u, entry_points.size());
+    EXPECT_EQ(0u, entry_points[0].workgroup_storage_size);
+}
+
+TEST_F(InspectorGetEntryPointTest, WorkgroupStorageSizeSimple) {
+    auto* src = R"(
+var<workgroup> wg_f32: f32;
+var<workgroup> wg_i32: i32;
+
+fn f32_func() { _ = wg_f32; }
+fn i32_func() { _ = wg_i32; }
+
+@compute @workgroup_size(1i)
+fn ep_func() {
+ f32_func();
+ i32_func();
+}
+)";
+    Inspector& inspector = Initialize(src);
+    auto entry_points = inspector.GetEntryPoints();
+    ASSERT_FALSE(inspector.has_error()) << inspector.error();
+    ASSERT_EQ(1u, entry_points.size());
+    // Both 4-byte scalars are reached through callees; each is padded to 16 bytes: 2 * 16 = 32.
+    EXPECT_EQ(32u, entry_points[0].workgroup_storage_size);
+}
+
+TEST_F(InspectorGetEntryPointTest, WorkgroupStorageSizeCompoundTypes) {
+    auto* src = R"(
+// This struct should occupy 68 bytes.
+struct WgStruct {
+ a: i32,
+ b: array<i32, 16>,
+}
+var<workgroup> wg_struct_var: WgStruct;
+
+fn wg_struct_func() { _ = wg_struct_var.a; }
+
+// Plus another 4 bytes from this other workgroup-class f32.
+var<workgroup> wg_f32: f32;
+fn f32_func() { _ = wg_f32; }
+
+@compute @workgroup_size(1i)
+fn ep_func() {
+ wg_struct_func();
+ f32_func();
+}
+)";
+    Inspector& inspector = Initialize(src);
+    auto entry_points = inspector.GetEntryPoints();
+    ASSERT_FALSE(inspector.has_error()) << inspector.error();
+    ASSERT_EQ(1u, entry_points.size());
+    // The 68-byte struct pads to 80 and the f32 pads to 16, giving 96 in total.
+    EXPECT_EQ(96u, entry_points[0].workgroup_storage_size);
+}
+
+TEST_F(InspectorGetEntryPointTest, WorkgroupStorageSizeAlignmentPadding) {
+    auto* src = R"(
+// vec3<f32> has an alignment of 16 but a size of 12. We leverage this to test
+// that our padded size calculation for workgroup storage is accurate.
+var<workgroup> wg_vec3: vec3f;
+
+fn wg_func() { _ = wg_vec3; }
+
+@compute @workgroup_size(1i)
+fn ep_func() {
+ wg_func();
+}
+)";
+    Inspector& inspector = Initialize(src);
+    auto entry_points = inspector.GetEntryPoints();
+    ASSERT_FALSE(inspector.has_error()) << inspector.error();
+    ASSERT_EQ(1u, entry_points.size());
+    // A size of 12 rounded up to the 16-byte alignment yields 16.
+    EXPECT_EQ(16u, entry_points[0].workgroup_storage_size);
+}
+
+TEST_F(InspectorGetEntryPointTest, WorkgroupStorageSizeStructAlignment) {
+    auto* src = R"(
+// Per the WGSL spec, a struct's size is the offset of its last member plus the
+// size of its last member, rounded up to the struct's alignment (the largest
+// member alignment). So here the struct occupies 1024 bytes of workgroup storage.
+struct WgStruct {
+ @align(1024i) a: f32,
+}
+var<workgroup> wg_struct_var: WgStruct;
+
+fn wg_struct_func() { _ = wg_struct_var.a; }
+
+@compute @workgroup_size(1i)
+fn ep_func() {
+ wg_struct_func();
+}
+)";
+    Inspector& inspector = Initialize(src);
+    auto result = inspector.GetEntryPoints();
+    ASSERT_FALSE(inspector.has_error()) << inspector.error();
+
+    ASSERT_EQ(1u, result.size());
+    EXPECT_EQ(1024u, result[0].workgroup_storage_size);
+}
+
TEST_F(InspectorGetEntryPointTest, NoInOutVariables) {
auto* src = R"(
fn func() {}