inspector: reflect workgroup storage size
This reflects the total size of all workgroup storage-class variables
referenced transitively by an entry point.
Bug: tint:919
Change-Id: If3a217fea5a875ac18db6de1579f004e368fbb7b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57740
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ken Rockot <rockot@google.com>
diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc
index f8b3b59..07e942a 100644
--- a/src/inspector/inspector.cc
+++ b/src/inspector/inspector.cc
@@ -42,6 +42,7 @@
#include "src/sem/variable.h"
#include "src/sem/vector_type.h"
#include "src/sem/void_type.h"
+#include "src/utils/math.h"
namespace tint {
namespace inspector {
@@ -534,6 +535,31 @@
return it->second;
}
+uint32_t Inspector::GetWorkgroupStorageSize(const std::string& entry_point) {
+ auto* func = FindEntryPointByName(entry_point);
+ if (!func) {
+ return 0;
+ }
+
+ uint32_t total_size = 0;
+ auto* func_sem = program_->Sem().Get(func);
+ for (const sem::Variable* var : func_sem->ReferencedModuleVariables()) {
+ if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
+ uint32_t align = 0;
+ uint32_t size = 0;
+ var->Type()->UnwrapRef()->GetDefaultAlignAndSize(align, size);
+
+ // This essentially matches std430 layout rules from GLSL, which are in
+ // turn specified as an upper bound for Vulkan layout sizing. Since D3D
+ // and Metal are even less specific, we assume Vulkan behavior as a
+ // good-enough approximation everywhere.
+ total_size += utils::RoundUp(align, size);
+ }
+ }
+
+ return total_size;
+}
+
ast::Function* Inspector::FindEntryPointByName(const std::string& name) {
auto* func = program_->AST().Functions().Find(program_->Symbols().Get(name));
if (!func) {
diff --git a/src/inspector/inspector.h b/src/inspector/inspector.h
index b7228e3..a386a10 100644
--- a/src/inspector/inspector.h
+++ b/src/inspector/inspector.h
@@ -132,6 +132,11 @@
std::vector<SamplerTexturePair> GetSamplerTextureUses(
const std::string& entry_point);
+ /// @param entry_point name of the entry point to get information about.
+ /// @returns the total size in bytes of all Workgroup storage-class storage
+ /// referenced transitively by the entry point.
+ uint32_t GetWorkgroupStorageSize(const std::string& entry_point);
+
private:
const Program* program_;
std::string error_;
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc
index 99e5917..7d7740a 100644
--- a/src/inspector/inspector_test.cc
+++ b/src/inspector/inspector_test.cc
@@ -134,6 +134,9 @@
class InspectorGetSamplerTextureUsesTest : public InspectorBuilder,
public testing::Test {};
+class InspectorGetWorkgroupStorageSizeTest : public InspectorBuilder,
+ public testing::Test {};
+
TEST_F(InspectorGetEntryPointTest, NoFunctions) {
Inspector& inspector = Build();
@@ -549,7 +552,7 @@
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) {
AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr);
- MakeConstReferenceBodyFunction(
+ MakePlainGlobalReferenceBodyFunction(
"ep_func", "foo", ty.f32(),
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
@@ -564,7 +567,7 @@
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) {
AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr);
- MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
+ MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
MakeCallerBodyFunction(
"ep_func", {"callee_func"},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
@@ -581,7 +584,7 @@
TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) {
AddOverridableConstantWithID<float>("foo", 1, ty.f32(), nullptr);
AddOverridableConstantWithID<float>("bar", 2, ty.f32(), nullptr);
- MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
+ MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
MakeCallerBodyFunction(
"ep_func", {"callee_func"},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
@@ -2397,6 +2400,93 @@
EXPECT_EQ(0u, result[0].texture_binding_point.binding);
}
+TEST_F(InspectorGetWorkgroupStorageSizeTest, Empty) {
+ MakeEmptyBodyFunction("ep_func",
+ ast::DecorationList{Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1)});
+ Inspector& inspector = Build();
+ EXPECT_EQ(0u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
+TEST_F(InspectorGetWorkgroupStorageSizeTest, Simple) {
+ AddWorkgroupStorage("wg_f32", ty.f32());
+ MakePlainGlobalReferenceBodyFunction("f32_func", "wg_f32", ty.f32(), {});
+
+ MakeCallerBodyFunction("ep_func", {"f32_func"},
+ ast::DecorationList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1),
+ });
+
+ Inspector& inspector = Build();
+ EXPECT_EQ(4u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
+TEST_F(InspectorGetWorkgroupStorageSizeTest, CompoundTypes) {
+ // This struct should occupy 68 bytes. 4 from the i32 field, and another 64
+ // from the 4-element array with 16-byte stride.
+ ast::Struct* wg_struct_type = MakeStructType(
+ "WgStruct", {ty.i32(), ty.array(ty.i32(), 4, /*stride=*/16)},
+ /*is_block=*/false);
+ AddWorkgroupStorage("wg_struct_var", ty.Of(wg_struct_type));
+ MakeStructVariableReferenceBodyFunction("wg_struct_func", "wg_struct_var",
+ {{0, ty.i32()}});
+
+ // Plus another 4 bytes from this other workgroup-class f32.
+ AddWorkgroupStorage("wg_f32", ty.f32());
+ MakePlainGlobalReferenceBodyFunction("f32_func", "wg_f32", ty.f32(), {});
+
+ MakeCallerBodyFunction("ep_func", {"wg_struct_func", "f32_func"},
+ ast::DecorationList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1),
+ });
+
+ Inspector& inspector = Build();
+ EXPECT_EQ(72u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
+TEST_F(InspectorGetWorkgroupStorageSizeTest, AlignmentPadding) {
+ // vec3<f32> has an alignment of 16 but a size of 12. We leverage this to test
+ // that our padded size calculation for workgroup storage is accurate.
+ AddWorkgroupStorage("wg_vec3", ty.vec3<f32>());
+ MakePlainGlobalReferenceBodyFunction("wg_func", "wg_vec3", ty.vec3<f32>(),
+ {});
+
+ MakeCallerBodyFunction("ep_func", {"wg_func"},
+ ast::DecorationList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1),
+ });
+
+ Inspector& inspector = Build();
+ EXPECT_EQ(16u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
+TEST_F(InspectorGetWorkgroupStorageSizeTest, StructAlignment) {
+ // Per WGSL spec, a struct's size is the offset its last member plus the size
+ // of its last member, rounded up to the alignment of its largest member. So
+ // here the struct is expected to occupy 1024 bytes of workgroup storage.
+ ast::Struct* wg_struct_type = MakeStructTypeFromMembers(
+ "WgStruct",
+ {MakeStructMember(0, ty.f32(),
+ {create<ast::StructMemberAlignDecoration>(1024)})},
+ /*is_block=*/false);
+
+ AddWorkgroupStorage("wg_struct_var", ty.Of(wg_struct_type));
+ MakeStructVariableReferenceBodyFunction("wg_struct_func", "wg_struct_var",
+ {{0, ty.f32()}});
+
+ MakeCallerBodyFunction("ep_func", {"wg_struct_func"},
+ ast::DecorationList{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1),
+ });
+
+ Inspector& inspector = Build();
+ EXPECT_EQ(1024u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
} // namespace
} // namespace inspector
} // namespace tint
diff --git a/src/inspector/test_inspector_builder.cc b/src/inspector/test_inspector_builder.cc
index 922f342..5667acb 100644
--- a/src/inspector/test_inspector_builder.cc
+++ b/src/inspector/test_inspector_builder.cc
@@ -60,7 +60,7 @@
return Structure(name, members);
}
-ast::Function* InspectorBuilder::MakeConstReferenceBodyFunction(
+ast::Function* InspectorBuilder::MakePlainGlobalReferenceBodyFunction(
std::string func,
std::string var,
ast::Type* type,
@@ -93,15 +93,27 @@
bool is_block) {
ast::StructMemberList members;
for (auto* type : member_types) {
- members.push_back(Member(StructMemberName(members.size(), type), type));
+ members.push_back(MakeStructMember(members.size(), type, {}));
}
+ return MakeStructTypeFromMembers(name, std::move(members), is_block);
+}
+ast::Struct* InspectorBuilder::MakeStructTypeFromMembers(
+ const std::string& name,
+ ast::StructMemberList members,
+ bool is_block) {
ast::DecorationList decos;
if (is_block) {
decos.push_back(create<ast::StructBlockDecoration>());
}
+ return Structure(name, std::move(members), decos);
+}
- return Structure(name, members, decos);
+ast::StructMember* InspectorBuilder::MakeStructMember(
+ size_t index,
+ ast::Type* type,
+ ast::DecorationList decorations) {
+ return Member(StructMemberName(index, type), type, std::move(decorations));
}
ast::Struct* InspectorBuilder::MakeUniformBufferType(
@@ -128,6 +140,11 @@
});
}
+void InspectorBuilder::AddWorkgroupStorage(const std::string& name,
+ ast::Type* type) {
+ Global(name, type, ast::StorageClass::kWorkgroup);
+}
+
void InspectorBuilder::AddStorageBuffer(const std::string& name,
ast::Type* type,
ast::Access access,
diff --git a/src/inspector/test_inspector_builder.h b/src/inspector/test_inspector_builder.h
index 3d1353d..7697601 100644
--- a/src/inspector/test_inspector_builder.h
+++ b/src/inspector/test_inspector_builder.h
@@ -139,13 +139,14 @@
});
}
- /// Generates a function that references module constant
+ /// Generates a function that references module-scoped, plain-typed constant
+ /// or variable.
/// @param func name of the function created
/// @param var name of the constant to be reference
/// @param type type of the const being referenced
/// @param decorations the function decorations
/// @returns a function object
- ast::Function* MakeConstReferenceBodyFunction(
+ ast::Function* MakePlainGlobalReferenceBodyFunction(
std::string func,
std::string var,
ast::Type* type,
@@ -172,6 +173,24 @@
std::vector<ast::Type*> member_types,
bool is_block);
+ /// Generates a struct type from a list of member nodes.
+ /// @param name name for the struct type
+ /// @param members a vector of members
+ /// @param is_block whether or not to decorate as a Block
+ /// @returns a struct type
+ ast::Struct* MakeStructTypeFromMembers(const std::string& name,
+ ast::StructMemberList members,
+ bool is_block);
+
+ /// Generates a struct member with a specified index and type.
+ /// @param index index of the field within the struct
+ /// @param type the type of the member field
+ /// @param decorations a list of decorations to apply to the member field
+ /// @returns a struct member
+ ast::StructMember* MakeStructMember(size_t index,
+ ast::Type* type,
+ ast::DecorationList decorations);
+
/// Generates types appropriate for using in an uniform buffer
/// @param name name for the type
/// @param member_types a vector of member types
@@ -197,6 +216,11 @@
uint32_t group,
uint32_t binding);
+ /// Adds a workgroup storage variable to the program
+ /// @param name the name of the variable
+ /// @param type the type of the variable
+ void AddWorkgroupStorage(const std::string& name, ast::Type* type);
+
/// Adds a storage buffer variable to the program
/// @param name the name of the variable
/// @param type the type to use
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 20c2739..492026c 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -723,7 +723,7 @@
auto required_alignment_of = [&](const sem::Type* ty) {
uint32_t actual_align = 0;
uint32_t actual_size = 0;
- DefaultAlignAndSize(ty, actual_align, actual_size);
+ ty->GetDefaultAlignAndSize(actual_align, actual_size);
uint32_t required_align = actual_align;
if (is_uniform_struct_or_array(ty)) {
required_align = utils::RoundUp(16u, actual_align);
@@ -3750,69 +3750,6 @@
}
}
-bool Resolver::DefaultAlignAndSize(const sem::Type* ty,
- uint32_t& align,
- uint32_t& size) {
- static constexpr uint32_t vector_size[] = {
- /* padding */ 0,
- /* padding */ 0,
- /*vec2*/ 8,
- /*vec3*/ 12,
- /*vec4*/ 16,
- };
- static constexpr uint32_t vector_align[] = {
- /* padding */ 0,
- /* padding */ 0,
- /*vec2*/ 8,
- /*vec3*/ 16,
- /*vec4*/ 16,
- };
-
- if (ty->is_scalar()) {
- // Note: Also captures booleans, but these are not host-shareable.
- align = 4;
- size = 4;
- return true;
- }
- if (auto* vec = ty->As<sem::Vector>()) {
- if (vec->size() < 2 || vec->size() > 4) {
- TINT_UNREACHABLE(Resolver, diagnostics_)
- << "Invalid vector size: vec" << vec->size();
- return false;
- }
- align = vector_align[vec->size()];
- size = vector_size[vec->size()];
- return true;
- }
- if (auto* mat = ty->As<sem::Matrix>()) {
- if (mat->columns() < 2 || mat->columns() > 4 || mat->rows() < 2 ||
- mat->rows() > 4) {
- TINT_UNREACHABLE(Resolver, diagnostics_)
- << "Invalid matrix size: mat" << mat->columns() << "x" << mat->rows();
- return false;
- }
- align = vector_align[mat->rows()];
- size = vector_align[mat->rows()] * mat->columns();
- return true;
- }
- if (auto* s = ty->As<sem::Struct>()) {
- align = s->Align();
- size = s->Size();
- return true;
- }
- if (auto* a = ty->As<sem::Array>()) {
- align = a->Align();
- size = a->SizeInBytes();
- return true;
- }
- if (auto* a = ty->As<sem::Atomic>()) {
- return DefaultAlignAndSize(a->Type(), align, size);
- }
- TINT_UNREACHABLE(Resolver, diagnostics_)
- << "invalid type " << ty->TypeInfo().name;
- return false;
-}
-
sem::Array* Resolver::Array(const ast::Array* arr) {
auto source = arr->source();
@@ -3821,7 +3758,7 @@
return nullptr;
}
- if (!IsPlain(el_ty)) { // Check must come before DefaultAlignAndSize()
+ if (!IsPlain(el_ty)) { // Check must come before GetDefaultAlignAndSize()
AddError(el_ty->FriendlyName(builder_->Symbols()) +
" cannot be used as an element type of an array",
source);
@@ -3830,9 +3767,7 @@
uint32_t el_align = 0;
uint32_t el_size = 0;
- if (!DefaultAlignAndSize(el_ty, el_align, el_size)) {
- return nullptr;
- }
+ el_ty->GetDefaultAlignAndSize(el_align, el_size);
if (!ValidateNoDuplicateDecorations(arr->decorations())) {
return nullptr;
@@ -4040,9 +3975,7 @@
uint32_t offset = struct_size;
uint32_t align = 0;
uint32_t size = 0;
- if (!DefaultAlignAndSize(type, align, size)) {
- return nullptr;
- }
+ type->GetDefaultAlignAndSize(align, size);
if (!ValidateNoDuplicateDecorations(member->decorations())) {
return nullptr;
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index ce82686..ad7e56d 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -375,13 +375,6 @@
sem::Type* ty,
const Source& usage);
- /// @param align the output default alignment in bytes for the type `ty`
- /// @param size the output default size in bytes for the type `ty`
- /// @returns true on success, false on error
- bool DefaultAlignAndSize(const sem::Type* ty,
- uint32_t& align,
- uint32_t& size);
-
/// @param storage_class the storage class
/// @returns the default access control for the given storage class
ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class);
diff --git a/src/sem/type.cc b/src/sem/type.cc
index 97a9b49..bddc470 100644
--- a/src/sem/type.cc
+++ b/src/sem/type.cc
@@ -14,6 +14,9 @@
#include "src/sem/type.h"
+#include "src/debug.h"
+#include "src/sem/array.h"
+#include "src/sem/atomic_type.h"
#include "src/sem/bool_type.h"
#include "src/sem/f32_type.h"
#include "src/sem/i32_type.h"
@@ -21,6 +24,7 @@
#include "src/sem/pointer_type.h"
#include "src/sem/reference_type.h"
#include "src/sem/sampler_type.h"
+#include "src/sem/struct.h"
#include "src/sem/texture_type.h"
#include "src/sem/u32_type.h"
#include "src/sem/vector_type.h"
@@ -52,6 +56,61 @@
return type;
}
+void Type::GetDefaultAlignAndSize(uint32_t& align, uint32_t& size) const {
+ TINT_ASSERT(Semantic, !As<Reference>());
+ TINT_ASSERT(Semantic, !As<Pointer>());
+
+ static constexpr uint32_t vector_size[] = {
+ /* padding */ 0,
+ /* padding */ 0,
+ /*vec2*/ 8,
+ /*vec3*/ 12,
+ /*vec4*/ 16,
+ };
+ static constexpr uint32_t vector_align[] = {
+ /* padding */ 0,
+ /* padding */ 0,
+ /*vec2*/ 8,
+ /*vec3*/ 16,
+ /*vec4*/ 16,
+ };
+
+ if (is_scalar()) {
+ // Note: Also captures booleans, but these are not host-shareable.
+ align = 4;
+ size = 4;
+ return;
+ }
+ if (auto* vec = As<Vector>()) {
+ TINT_ASSERT(Semantic, vec->size() >= 2 && vec->size() <= 4);
+ align = vector_align[vec->size()];
+ size = vector_size[vec->size()];
+ return;
+ }
+ if (auto* mat = As<Matrix>()) {
+ TINT_ASSERT(Semantic, mat->columns() >= 2 && mat->columns() <= 4);
+ TINT_ASSERT(Semantic, mat->rows() >= 2 && mat->rows() <= 4);
+ align = vector_align[mat->rows()];
+ size = vector_align[mat->rows()] * mat->columns();
+ return;
+ }
+ if (auto* s = As<Struct>()) {
+ align = s->Align();
+ size = s->Size();
+ return;
+ }
+ if (auto* a = As<Array>()) {
+ align = a->Align();
+ size = a->SizeInBytes();
+ return;
+ }
+ if (auto* a = As<Atomic>()) {
+ return a->Type()->GetDefaultAlignAndSize(align, size);
+ }
+
+ TINT_ASSERT(Semantic, false);
+}
+
bool Type::is_scalar() const {
return IsAnyOf<F32, U32, I32, Bool>();
}
diff --git a/src/sem/type.h b/src/sem/type.h
index f48fa5f..3a268d4 100644
--- a/src/sem/type.h
+++ b/src/sem/type.h
@@ -52,6 +52,10 @@
/// @returns the inner type if this is a reference, `this` otherwise
const Type* UnwrapRef() const;
+ /// @param align the output default alignment in bytes for this type.
+ /// @param size the output default size in bytes for this type.
+ void GetDefaultAlignAndSize(uint32_t& align, uint32_t& size) const;
+
/// @returns true if this type is a scalar
bool is_scalar() const;
/// @returns true if this type is a numeric scalar