inspector: reflect workgroup storage size

This reflects the total size of all workgroup storage-class variables
referenced transitively by an entry point.

Bug: tint:919
Change-Id: If3a217fea5a875ac18db6de1579f004e368fbb7b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57740
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ken Rockot <rockot@google.com>
diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc
index f8b3b59..07e942a 100644
--- a/src/inspector/inspector.cc
+++ b/src/inspector/inspector.cc
@@ -42,6 +42,7 @@
 #include "src/sem/variable.h"
 #include "src/sem/vector_type.h"
 #include "src/sem/void_type.h"
+#include "src/utils/math.h"
 
 namespace tint {
 namespace inspector {
@@ -534,6 +535,31 @@
   return it->second;
 }
 
+uint32_t Inspector::GetWorkgroupStorageSize(const std::string& entry_point) {
+  auto* func = FindEntryPointByName(entry_point);
+  if (!func) {
+    return 0;
+  }
+
+  uint32_t total_size = 0;
+  auto* func_sem = program_->Sem().Get(func);
+  for (const sem::Variable* var : func_sem->ReferencedModuleVariables()) {
+    if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
+      uint32_t align = 0;
+      uint32_t size = 0;
+      var->Type()->UnwrapRef()->GetDefaultAlignAndSize(align, size);
+
+      // This essentially matches std430 layout rules from GLSL, which are in
+      // turn specified as an upper bound for Vulkan layout sizing. Since D3D
+      // and Metal are even less specific, we assume Vulkan behavior as a
+      // good-enough approximation everywhere.
+      total_size += utils::RoundUp(align, size);
+    }
+  }
+
+  return total_size;
+}
+
 ast::Function* Inspector::FindEntryPointByName(const std::string& name) {
   auto* func = program_->AST().Functions().Find(program_->Symbols().Get(name));
   if (!func) {
diff --git a/src/inspector/inspector.h b/src/inspector/inspector.h
index b7228e3..a386a10 100644
--- a/src/inspector/inspector.h
+++ b/src/inspector/inspector.h
@@ -132,6 +132,11 @@
   std::vector<SamplerTexturePair> GetSamplerTextureUses(
       const std::string& entry_point);
 
+  /// @param entry_point name of the entry point to get information about.
+  /// @returns the total size in bytes of all Workgroup storage-class storage
+  /// referenced transitively by the entry point.
+  uint32_t GetWorkgroupStorageSize(const std::string& entry_point);
+
  private:
   const Program* program_;
   std::string error_;
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc
index 99e5917..7d7740a 100644
--- a/src/inspector/inspector_test.cc
+++ b/src/inspector/inspector_test.cc
@@ -134,6 +134,9 @@
 class InspectorGetSamplerTextureUsesTest : public InspectorBuilder,
                                            public testing::Test {};
 
+class InspectorGetWorkgroupStorageSizeTest : public InspectorBuilder,
+                                             public testing::Test {};
+
 TEST_F(InspectorGetEntryPointTest, NoFunctions) {
   Inspector& inspector = Build();
 
@@ -549,7 +552,7 @@
 
 TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) {
   AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr);
-  MakeConstReferenceBodyFunction(
+  MakePlainGlobalReferenceBodyFunction(
       "ep_func", "foo", ty.f32(),
       {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
 
@@ -564,7 +567,7 @@
 
 TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) {
   AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr);
-  MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
+  MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
   MakeCallerBodyFunction(
       "ep_func", {"callee_func"},
       {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
@@ -581,7 +584,7 @@
 TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) {
   AddOverridableConstantWithID<float>("foo", 1, ty.f32(), nullptr);
   AddOverridableConstantWithID<float>("bar", 2, ty.f32(), nullptr);
-  MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
+  MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
   MakeCallerBodyFunction(
       "ep_func", {"callee_func"},
       {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
@@ -2397,6 +2400,93 @@
   EXPECT_EQ(0u, result[0].texture_binding_point.binding);
 }
 
+TEST_F(InspectorGetWorkgroupStorageSizeTest, Empty) {
+  MakeEmptyBodyFunction("ep_func",
+                        ast::DecorationList{Stage(ast::PipelineStage::kCompute),
+                                            WorkgroupSize(1)});
+  Inspector& inspector = Build();
+  EXPECT_EQ(0u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
+TEST_F(InspectorGetWorkgroupStorageSizeTest, Simple) {
+  AddWorkgroupStorage("wg_f32", ty.f32());
+  MakePlainGlobalReferenceBodyFunction("f32_func", "wg_f32", ty.f32(), {});
+
+  MakeCallerBodyFunction("ep_func", {"f32_func"},
+                         ast::DecorationList{
+                             Stage(ast::PipelineStage::kCompute),
+                             WorkgroupSize(1),
+                         });
+
+  Inspector& inspector = Build();
+  EXPECT_EQ(4u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
+TEST_F(InspectorGetWorkgroupStorageSizeTest, CompoundTypes) {
+  // This struct should occupy 68 bytes. 4 from the i32 field, and another 64
+  // from the 4-element array with 16-byte stride.
+  ast::Struct* wg_struct_type = MakeStructType(
+      "WgStruct", {ty.i32(), ty.array(ty.i32(), 4, /*stride=*/16)},
+      /*is_block=*/false);
+  AddWorkgroupStorage("wg_struct_var", ty.Of(wg_struct_type));
+  MakeStructVariableReferenceBodyFunction("wg_struct_func", "wg_struct_var",
+                                          {{0, ty.i32()}});
+
+  // Plus another 4 bytes from this other workgroup-class f32.
+  AddWorkgroupStorage("wg_f32", ty.f32());
+  MakePlainGlobalReferenceBodyFunction("f32_func", "wg_f32", ty.f32(), {});
+
+  MakeCallerBodyFunction("ep_func", {"wg_struct_func", "f32_func"},
+                         ast::DecorationList{
+                             Stage(ast::PipelineStage::kCompute),
+                             WorkgroupSize(1),
+                         });
+
+  Inspector& inspector = Build();
+  EXPECT_EQ(72u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
+TEST_F(InspectorGetWorkgroupStorageSizeTest, AlignmentPadding) {
+  // vec3<f32> has an alignment of 16 but a size of 12. We leverage this to test
+  // that our padded size calculation for workgroup storage is accurate.
+  AddWorkgroupStorage("wg_vec3", ty.vec3<f32>());
+  MakePlainGlobalReferenceBodyFunction("wg_func", "wg_vec3", ty.vec3<f32>(),
+                                       {});
+
+  MakeCallerBodyFunction("ep_func", {"wg_func"},
+                         ast::DecorationList{
+                             Stage(ast::PipelineStage::kCompute),
+                             WorkgroupSize(1),
+                         });
+
+  Inspector& inspector = Build();
+  EXPECT_EQ(16u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
+TEST_F(InspectorGetWorkgroupStorageSizeTest, StructAlignment) {
+  // Per WGSL spec, a struct's size is the offset its last member plus the size
+  // of its last member, rounded up to the alignment of its largest member. So
+  // here the struct is expected to occupy 1024 bytes of workgroup storage.
+  ast::Struct* wg_struct_type = MakeStructTypeFromMembers(
+      "WgStruct",
+      {MakeStructMember(0, ty.f32(),
+                        {create<ast::StructMemberAlignDecoration>(1024)})},
+      /*is_block=*/false);
+
+  AddWorkgroupStorage("wg_struct_var", ty.Of(wg_struct_type));
+  MakeStructVariableReferenceBodyFunction("wg_struct_func", "wg_struct_var",
+                                          {{0, ty.f32()}});
+
+  MakeCallerBodyFunction("ep_func", {"wg_struct_func"},
+                         ast::DecorationList{
+                             Stage(ast::PipelineStage::kCompute),
+                             WorkgroupSize(1),
+                         });
+
+  Inspector& inspector = Build();
+  EXPECT_EQ(1024u, inspector.GetWorkgroupStorageSize("ep_func"));
+}
+
 }  // namespace
 }  // namespace inspector
 }  // namespace tint
diff --git a/src/inspector/test_inspector_builder.cc b/src/inspector/test_inspector_builder.cc
index 922f342..5667acb 100644
--- a/src/inspector/test_inspector_builder.cc
+++ b/src/inspector/test_inspector_builder.cc
@@ -60,7 +60,7 @@
   return Structure(name, members);
 }
 
-ast::Function* InspectorBuilder::MakeConstReferenceBodyFunction(
+ast::Function* InspectorBuilder::MakePlainGlobalReferenceBodyFunction(
     std::string func,
     std::string var,
     ast::Type* type,
@@ -93,15 +93,27 @@
     bool is_block) {
   ast::StructMemberList members;
   for (auto* type : member_types) {
-    members.push_back(Member(StructMemberName(members.size(), type), type));
+    members.push_back(MakeStructMember(members.size(), type, {}));
   }
+  return MakeStructTypeFromMembers(name, std::move(members), is_block);
+}
 
+ast::Struct* InspectorBuilder::MakeStructTypeFromMembers(
+    const std::string& name,
+    ast::StructMemberList members,
+    bool is_block) {
   ast::DecorationList decos;
   if (is_block) {
     decos.push_back(create<ast::StructBlockDecoration>());
   }
+  return Structure(name, std::move(members), decos);
+}
 
-  return Structure(name, members, decos);
+ast::StructMember* InspectorBuilder::MakeStructMember(
+    size_t index,
+    ast::Type* type,
+    ast::DecorationList decorations) {
+  return Member(StructMemberName(index, type), type, std::move(decorations));
 }
 
 ast::Struct* InspectorBuilder::MakeUniformBufferType(
@@ -128,6 +140,11 @@
          });
 }
 
+void InspectorBuilder::AddWorkgroupStorage(const std::string& name,
+                                           ast::Type* type) {
+  Global(name, type, ast::StorageClass::kWorkgroup);
+}
+
 void InspectorBuilder::AddStorageBuffer(const std::string& name,
                                         ast::Type* type,
                                         ast::Access access,
diff --git a/src/inspector/test_inspector_builder.h b/src/inspector/test_inspector_builder.h
index 3d1353d..7697601 100644
--- a/src/inspector/test_inspector_builder.h
+++ b/src/inspector/test_inspector_builder.h
@@ -139,13 +139,14 @@
                        });
   }
 
-  /// Generates a function that references module constant
+  /// Generates a function that references module-scoped, plain-typed constant
+  /// or variable.
   /// @param func name of the function created
   /// @param var name of the constant to be reference
   /// @param type type of the const being referenced
   /// @param decorations the function decorations
   /// @returns a function object
-  ast::Function* MakeConstReferenceBodyFunction(
+  ast::Function* MakePlainGlobalReferenceBodyFunction(
       std::string func,
       std::string var,
       ast::Type* type,
@@ -172,6 +173,24 @@
                               std::vector<ast::Type*> member_types,
                               bool is_block);
 
+  /// Generates a struct type from a list of member nodes.
+  /// @param name name for the struct type
+  /// @param members a vector of members
+  /// @param is_block whether or not to decorate as a Block
+  /// @returns a struct type
+  ast::Struct* MakeStructTypeFromMembers(const std::string& name,
+                                         ast::StructMemberList members,
+                                         bool is_block);
+
+  /// Generates a struct member with a specified index and type.
+  /// @param index index of the field within the struct
+  /// @param type the type of the member field
+  /// @param decorations a list of decorations to apply to the member field
+  /// @returns a struct member
+  ast::StructMember* MakeStructMember(size_t index,
+                                      ast::Type* type,
+                                      ast::DecorationList decorations);
+
   /// Generates types appropriate for using in an uniform buffer
   /// @param name name for the type
   /// @param member_types a vector of member types
@@ -197,6 +216,11 @@
                         uint32_t group,
                         uint32_t binding);
 
+  /// Adds a workgroup storage variable to the program
+  /// @param name the name of the variable
+  /// @param type the type of the variable
+  void AddWorkgroupStorage(const std::string& name, ast::Type* type);
+
   /// Adds a storage buffer variable to the program
   /// @param name the name of the variable
   /// @param type the type to use
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 20c2739..492026c 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -723,7 +723,7 @@
   auto required_alignment_of = [&](const sem::Type* ty) {
     uint32_t actual_align = 0;
     uint32_t actual_size = 0;
-    DefaultAlignAndSize(ty, actual_align, actual_size);
+    ty->GetDefaultAlignAndSize(actual_align, actual_size);
     uint32_t required_align = actual_align;
     if (is_uniform_struct_or_array(ty)) {
       required_align = utils::RoundUp(16u, actual_align);
@@ -3750,69 +3750,6 @@
   }
 }
 
-bool Resolver::DefaultAlignAndSize(const sem::Type* ty,
-                                   uint32_t& align,
-                                   uint32_t& size) {
-  static constexpr uint32_t vector_size[] = {
-      /* padding */ 0,
-      /* padding */ 0,
-      /*vec2*/ 8,
-      /*vec3*/ 12,
-      /*vec4*/ 16,
-  };
-  static constexpr uint32_t vector_align[] = {
-      /* padding */ 0,
-      /* padding */ 0,
-      /*vec2*/ 8,
-      /*vec3*/ 16,
-      /*vec4*/ 16,
-  };
-
-  if (ty->is_scalar()) {
-    // Note: Also captures booleans, but these are not host-shareable.
-    align = 4;
-    size = 4;
-    return true;
-  }
-  if (auto* vec = ty->As<sem::Vector>()) {
-    if (vec->size() < 2 || vec->size() > 4) {
-      TINT_UNREACHABLE(Resolver, diagnostics_)
-          << "Invalid vector size: vec" << vec->size();
-      return false;
-    }
-    align = vector_align[vec->size()];
-    size = vector_size[vec->size()];
-    return true;
-  }
-  if (auto* mat = ty->As<sem::Matrix>()) {
-    if (mat->columns() < 2 || mat->columns() > 4 || mat->rows() < 2 ||
-        mat->rows() > 4) {
-      TINT_UNREACHABLE(Resolver, diagnostics_)
-          << "Invalid matrix size: mat" << mat->columns() << "x" << mat->rows();
-      return false;
-    }
-    align = vector_align[mat->rows()];
-    size = vector_align[mat->rows()] * mat->columns();
-    return true;
-  }
-  if (auto* s = ty->As<sem::Struct>()) {
-    align = s->Align();
-    size = s->Size();
-    return true;
-  }
-  if (auto* a = ty->As<sem::Array>()) {
-    align = a->Align();
-    size = a->SizeInBytes();
-    return true;
-  }
-  if (auto* a = ty->As<sem::Atomic>()) {
-    return DefaultAlignAndSize(a->Type(), align, size);
-  }
-  TINT_UNREACHABLE(Resolver, diagnostics_)
-      << "invalid type " << ty->TypeInfo().name;
-  return false;
-}
-
 sem::Array* Resolver::Array(const ast::Array* arr) {
   auto source = arr->source();
 
@@ -3821,7 +3758,7 @@
     return nullptr;
   }
 
-  if (!IsPlain(el_ty)) {  // Check must come before DefaultAlignAndSize()
+  if (!IsPlain(el_ty)) {  // Check must come before GetDefaultAlignAndSize()
     AddError(el_ty->FriendlyName(builder_->Symbols()) +
                  " cannot be used as an element type of an array",
              source);
@@ -3830,9 +3767,7 @@
 
   uint32_t el_align = 0;
   uint32_t el_size = 0;
-  if (!DefaultAlignAndSize(el_ty, el_align, el_size)) {
-    return nullptr;
-  }
+  el_ty->GetDefaultAlignAndSize(el_align, el_size);
 
   if (!ValidateNoDuplicateDecorations(arr->decorations())) {
     return nullptr;
@@ -4040,9 +3975,7 @@
     uint32_t offset = struct_size;
     uint32_t align = 0;
     uint32_t size = 0;
-    if (!DefaultAlignAndSize(type, align, size)) {
-      return nullptr;
-    }
+    type->GetDefaultAlignAndSize(align, size);
 
     if (!ValidateNoDuplicateDecorations(member->decorations())) {
       return nullptr;
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index ce82686..ad7e56d 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -375,13 +375,6 @@
                                     sem::Type* ty,
                                     const Source& usage);
 
-  /// @param align the output default alignment in bytes for the type `ty`
-  /// @param size the output default size in bytes for the type `ty`
-  /// @returns true on success, false on error
-  bool DefaultAlignAndSize(const sem::Type* ty,
-                           uint32_t& align,
-                           uint32_t& size);
-
   /// @param storage_class the storage class
   /// @returns the default access control for the given storage class
   ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class);
diff --git a/src/sem/type.cc b/src/sem/type.cc
index 97a9b49..bddc470 100644
--- a/src/sem/type.cc
+++ b/src/sem/type.cc
@@ -14,6 +14,9 @@
 
 #include "src/sem/type.h"
 
+#include "src/debug.h"
+#include "src/sem/array.h"
+#include "src/sem/atomic_type.h"
 #include "src/sem/bool_type.h"
 #include "src/sem/f32_type.h"
 #include "src/sem/i32_type.h"
@@ -21,6 +24,7 @@
 #include "src/sem/pointer_type.h"
 #include "src/sem/reference_type.h"
 #include "src/sem/sampler_type.h"
+#include "src/sem/struct.h"
 #include "src/sem/texture_type.h"
 #include "src/sem/u32_type.h"
 #include "src/sem/vector_type.h"
@@ -52,6 +56,61 @@
   return type;
 }
 
+void Type::GetDefaultAlignAndSize(uint32_t& align, uint32_t& size) const {
+  TINT_ASSERT(Semantic, !As<Reference>());
+  TINT_ASSERT(Semantic, !As<Pointer>());
+
+  static constexpr uint32_t vector_size[] = {
+      /* padding */ 0,
+      /* padding */ 0,
+      /*vec2*/ 8,
+      /*vec3*/ 12,
+      /*vec4*/ 16,
+  };
+  static constexpr uint32_t vector_align[] = {
+      /* padding */ 0,
+      /* padding */ 0,
+      /*vec2*/ 8,
+      /*vec3*/ 16,
+      /*vec4*/ 16,
+  };
+
+  if (is_scalar()) {
+    // Note: Also captures booleans, but these are not host-shareable.
+    align = 4;
+    size = 4;
+    return;
+  }
+  if (auto* vec = As<Vector>()) {
+    TINT_ASSERT(Semantic, vec->size() >= 2 && vec->size() <= 4);
+    align = vector_align[vec->size()];
+    size = vector_size[vec->size()];
+    return;
+  }
+  if (auto* mat = As<Matrix>()) {
+    TINT_ASSERT(Semantic, mat->columns() >= 2 && mat->columns() <= 4);
+    TINT_ASSERT(Semantic, mat->rows() >= 2 && mat->rows() <= 4);
+    align = vector_align[mat->rows()];
+    size = vector_align[mat->rows()] * mat->columns();
+    return;
+  }
+  if (auto* s = As<Struct>()) {
+    align = s->Align();
+    size = s->Size();
+    return;
+  }
+  if (auto* a = As<Array>()) {
+    align = a->Align();
+    size = a->SizeInBytes();
+    return;
+  }
+  if (auto* a = As<Atomic>()) {
+    return a->Type()->GetDefaultAlignAndSize(align, size);
+  }
+
+  TINT_ASSERT(Semantic, false);
+}
+
 bool Type::is_scalar() const {
   return IsAnyOf<F32, U32, I32, Bool>();
 }
diff --git a/src/sem/type.h b/src/sem/type.h
index f48fa5f..3a268d4 100644
--- a/src/sem/type.h
+++ b/src/sem/type.h
@@ -52,6 +52,10 @@
   /// @returns the inner type if this is a reference, `this` otherwise
   const Type* UnwrapRef() const;
 
+  /// @param align the output default alignment in bytes for this type.
+  /// @param size the output default size in bytes for this type.
+  void GetDefaultAlignAndSize(uint32_t& align, uint32_t& size) const;
+
   /// @returns true if this type is a scalar
   bool is_scalar() const;
   /// @returns true if this type is a numeric scalar