Update MSL Generator to return workgroup information.

This CL adds an output to the MSL generator to return the workgroup
size information after `SubstituteOverrides` has been run.

Bug: 380043961
Change-Id: I75dad66bed3afeabdc18756f75b1e109b2a5344b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/218856
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index bd5ead5..29d4380 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -375,7 +375,7 @@
             )" + msl;
 
             auto workgroupAllocations =
-                std::move(result->workgroup_allocations.at(kRemappedEntryPointName));
+                std::move(result->workgroup_info.allocations.at(kRemappedEntryPointName));
             return MslCompilation{{
                 std::move(msl),
                 std::move(kRemappedEntryPointName),
diff --git a/src/tint/lang/msl/writer/BUILD.bazel b/src/tint/lang/msl/writer/BUILD.bazel
index 766902e..489aadd 100644
--- a/src/tint/lang/msl/writer/BUILD.bazel
+++ b/src/tint/lang/msl/writer/BUILD.bazel
@@ -39,11 +39,9 @@
 cc_library(
   name = "writer",
   srcs = [
-    "output.cc",
     "writer.cc",
   ],
   hdrs = [
-    "output.h",
     "writer.h",
   ],
   deps = [
diff --git a/src/tint/lang/msl/writer/BUILD.cmake b/src/tint/lang/msl/writer/BUILD.cmake
index fe540d4..536378a 100644
--- a/src/tint/lang/msl/writer/BUILD.cmake
+++ b/src/tint/lang/msl/writer/BUILD.cmake
@@ -48,8 +48,6 @@
 # Condition: TINT_BUILD_MSL_WRITER
 ################################################################################
 tint_add_target(tint_lang_msl_writer lib
-  lang/msl/writer/output.cc
-  lang/msl/writer/output.h
   lang/msl/writer/writer.cc
   lang/msl/writer/writer.h
 )
diff --git a/src/tint/lang/msl/writer/BUILD.gn b/src/tint/lang/msl/writer/BUILD.gn
index 118d78a..9a2501bd 100644
--- a/src/tint/lang/msl/writer/BUILD.gn
+++ b/src/tint/lang/msl/writer/BUILD.gn
@@ -45,8 +45,6 @@
 if (tint_build_msl_writer) {
   libtint_source_set("writer") {
     sources = [
-      "output.cc",
-      "output.h",
       "writer.cc",
       "writer.h",
     ]
diff --git a/src/tint/lang/msl/writer/common/BUILD.bazel b/src/tint/lang/msl/writer/common/BUILD.bazel
index 282ca55..89e8002 100644
--- a/src/tint/lang/msl/writer/common/BUILD.bazel
+++ b/src/tint/lang/msl/writer/common/BUILD.bazel
@@ -41,11 +41,13 @@
   srcs = [
     "option_helpers.cc",
     "options.cc",
+    "output.cc",
     "printer_support.cc",
   ],
   hdrs = [
     "option_helpers.h",
     "options.h",
+    "output.h",
     "printer_support.h",
   ],
   deps = [
diff --git a/src/tint/lang/msl/writer/common/BUILD.cmake b/src/tint/lang/msl/writer/common/BUILD.cmake
index fd56a5e..bd173cc 100644
--- a/src/tint/lang/msl/writer/common/BUILD.cmake
+++ b/src/tint/lang/msl/writer/common/BUILD.cmake
@@ -45,6 +45,8 @@
   lang/msl/writer/common/option_helpers.h
   lang/msl/writer/common/options.cc
   lang/msl/writer/common/options.h
+  lang/msl/writer/common/output.cc
+  lang/msl/writer/common/output.h
   lang/msl/writer/common/printer_support.cc
   lang/msl/writer/common/printer_support.h
 )
diff --git a/src/tint/lang/msl/writer/common/BUILD.gn b/src/tint/lang/msl/writer/common/BUILD.gn
index 8567732..a1fb165 100644
--- a/src/tint/lang/msl/writer/common/BUILD.gn
+++ b/src/tint/lang/msl/writer/common/BUILD.gn
@@ -49,6 +49,8 @@
       "option_helpers.h",
       "options.cc",
       "options.h",
+      "output.cc",
+      "output.h",
       "printer_support.cc",
       "printer_support.h",
     ]
diff --git a/src/tint/lang/msl/writer/output.cc b/src/tint/lang/msl/writer/common/output.cc
similarity index 96%
rename from src/tint/lang/msl/writer/output.cc
rename to src/tint/lang/msl/writer/common/output.cc
index cc2f3a0..14cf191 100644
--- a/src/tint/lang/msl/writer/output.cc
+++ b/src/tint/lang/msl/writer/common/output.cc
@@ -25,7 +25,7 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-#include "src/tint/lang/msl/writer/output.h"
+#include "src/tint/lang/msl/writer/common/output.h"
 
 namespace tint::msl::writer {
 
diff --git a/src/tint/lang/msl/writer/output.h b/src/tint/lang/msl/writer/common/output.h
similarity index 74%
rename from src/tint/lang/msl/writer/output.h
rename to src/tint/lang/msl/writer/common/output.h
index edcd2b2d..0d933df 100644
--- a/src/tint/lang/msl/writer/output.h
+++ b/src/tint/lang/msl/writer/common/output.h
@@ -25,13 +25,12 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-#ifndef SRC_TINT_LANG_MSL_WRITER_OUTPUT_H_
-#define SRC_TINT_LANG_MSL_WRITER_OUTPUT_H_
+#ifndef SRC_TINT_LANG_MSL_WRITER_COMMON_OUTPUT_H_
+#define SRC_TINT_LANG_MSL_WRITER_COMMON_OUTPUT_H_
 
 #include <cstdint>
 #include <string>
 #include <unordered_map>
-#include <unordered_set>
 #include <vector>
 
 namespace tint::msl::writer {
@@ -51,6 +50,21 @@
     /// @returns this
     Output& operator=(const Output&);
 
+    /// Workgroup size information
+    struct WorkgroupInfo {
+        /// The x-component
+        uint32_t x = 0;
+        /// The y-component
+        uint32_t y = 0;
+        /// The z-component
+        uint32_t z = 0;
+
+        /// A map from entry point name to a list of dynamic workgroup allocations.
+        /// Each entry in the vector is the size of the workgroup allocation that
+        /// should be created for that index.
+        std::unordered_map<std::string, std::vector<uint32_t>> allocations;
+    };
+
     /// The generated MSL.
     std::string msl = "";
 
@@ -60,12 +74,10 @@
     /// True if the generated shader uses the invariant attribute.
     bool has_invariant_attribute = false;
 
-    /// A map from entry point name to a list of dynamic workgroup allocations.
-    /// Each entry in the vector is the size of the workgroup allocation that
-    /// should be created for that index.
-    std::unordered_map<std::string, std::vector<uint32_t>> workgroup_allocations;
+    /// The workgroup size information, if the entry point was a compute shader
+    WorkgroupInfo workgroup_info{};
 };
 
 }  // namespace tint::msl::writer
 
-#endif  // SRC_TINT_LANG_MSL_WRITER_OUTPUT_H_
+#endif  // SRC_TINT_LANG_MSL_WRITER_COMMON_OUTPUT_H_
diff --git a/src/tint/lang/msl/writer/function_test.cc b/src/tint/lang/msl/writer/function_test.cc
index d223271..96c6d9f 100644
--- a/src/tint/lang/msl/writer/function_test.cc
+++ b/src/tint/lang/msl/writer/function_test.cc
@@ -28,7 +28,8 @@
 #include "src/tint/lang/core/type/sampled_texture.h"
 #include "src/tint/lang/msl/writer/helper_test.h"
 
-using namespace tint::core::fluent_types;  // NOLINT
+using namespace tint::core::fluent_types;     // NOLINT
+using namespace tint::core::number_suffixes;  // NOLINT
 
 namespace tint::msl::writer {
 namespace {
@@ -42,6 +43,28 @@
 void foo() {
 }
 )");
+
+    // MSL doesn't inject an empty entry point, so in this case there is no result.
+    EXPECT_EQ(output_.workgroup_info.x, 0u);
+    EXPECT_EQ(output_.workgroup_info.y, 0u);
+    EXPECT_EQ(output_.workgroup_info.z, 0u);
+}
+
+TEST_F(MslWriterTest, Function_EntryPoint_Compute) {
+    auto* func = b.ComputeFunction("cmp_main", 32_u, 4_u, 1_u);
+    b.Append(func->Block(), [&] {  //
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.msl;
+    EXPECT_EQ(output_.msl, MetalHeader() + R"(
+kernel void cmp_main() {
+}
+)");
+
+    EXPECT_EQ(output_.workgroup_info.x, 32u);
+    EXPECT_EQ(output_.workgroup_info.y, 4u);
+    EXPECT_EQ(output_.workgroup_info.z, 1u);
 }
 
 TEST_F(MslWriterTest, EntryPointParameterBufferBindingPoint) {
@@ -70,6 +93,9 @@
   tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.storage_var=storage_var, .uniform_var=uniform_var};
 }
 )");
+    EXPECT_EQ(output_.workgroup_info.x, 0u);
+    EXPECT_EQ(output_.workgroup_info.y, 0u);
+    EXPECT_EQ(output_.workgroup_info.z, 0u);
 }
 
 TEST_F(MslWriterTest, EntryPointParameterHandleBindingPoint) {
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index ab1dd56..a842976 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -119,7 +119,7 @@
         : ir_(module), options_(options) {}
 
     /// @returns the generated MSL shader
-    tint::Result<PrintResult> Generate() {
+    tint::Result<Output> Generate() {
         auto valid = core::ir::ValidateAndDumpIfNeeded(
             ir_, "msl.Printer",
             core::ir::Capabilities{
@@ -157,7 +157,7 @@
 
   private:
     /// The result of printing the module.
-    PrintResult result_;
+    Output result_;
 
     /// Map of builtin structure to unique generated name
     Hashmap<const core::type::Struct*, std::string, 4> builtin_struct_names_;
@@ -319,9 +319,20 @@
             auto func_name = NameOf(func);
 
             switch (func->Stage()) {
-                case core::ir::Function::PipelineStage::kCompute:
+                case core::ir::Function::PipelineStage::kCompute: {
                     out << "kernel ";
+
+                    auto const_wg_size = func->WorkgroupSizeAsConst();
+                    TINT_ASSERT(const_wg_size);
+                    auto wg_size = *const_wg_size;
+
+                    // Store the workgroup information away to return from the generator.
+                    result_.workgroup_info.x = wg_size[0];
+                    result_.workgroup_info.y = wg_size[1];
+                    result_.workgroup_info.z = wg_size[2];
+
                     break;
+                }
                 case core::ir::Function::PipelineStage::kFragment:
                     out << "fragment ";
                     break;
@@ -332,7 +343,7 @@
                     break;
             }
             if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
-                result_.workgroup_allocations.insert({func_name, {}});
+                result_.workgroup_info.allocations.insert({func_name, {}});
             }
 
             EmitType(out, func->ReturnType());
@@ -395,7 +406,7 @@
                 }
                 if (ptr && ptr->AddressSpace() == core::AddressSpace::kWorkgroup &&
                     func->Stage() == core::ir::Function::PipelineStage::kCompute) {
-                    auto& allocations = result_.workgroup_allocations.at(func_name);
+                    auto& allocations = result_.workgroup_info.allocations.at(func_name);
                     out << " [[threadgroup(" << allocations.size() << ")]]";
                     allocations.push_back(ptr->StoreType()->Size());
                 }
@@ -1765,16 +1776,8 @@
 
 }  // namespace
 
-Result<PrintResult> Print(core::ir::Module& module, const Options& options) {
+Result<Output> Print(core::ir::Module& module, const Options& options) {
     return Printer{module, options}.Generate();
 }
 
-PrintResult::PrintResult() = default;
-
-PrintResult::~PrintResult() = default;
-
-PrintResult::PrintResult(const PrintResult&) = default;
-
-PrintResult& PrintResult::operator=(const PrintResult&) = default;
-
 }  // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.h b/src/tint/lang/msl/writer/printer/printer.h
index 18d49fd..4cbdbd5 100644
--- a/src/tint/lang/msl/writer/printer/printer.h
+++ b/src/tint/lang/msl/writer/printer/printer.h
@@ -33,6 +33,7 @@
 #include <vector>
 
 #include "src/tint/lang/msl/writer/common/options.h"
+#include "src/tint/lang/msl/writer/common/output.h"
 #include "src/tint/utils/result/result.h"
 
 // Forward declarations
@@ -42,36 +43,9 @@
 
 namespace tint::msl::writer {
 
-/// The output produced when printing MSL.
-struct PrintResult {
-    /// Constructor
-    PrintResult();
-
-    /// Destructor
-    ~PrintResult();
-
-    /// Copy constructor
-    PrintResult(const PrintResult&);
-
-    /// Copy assignment
-    /// @returns this
-    PrintResult& operator=(const PrintResult&);
-
-    /// The generated MSL.
-    std::string msl = "";
-
-    /// `true` if an invariant attribute was generated.
-    bool has_invariant_attribute = false;
-
-    /// A map from entry point name to a list of dynamic workgroup allocations.
-    /// Each element of the vector is the size of the workgroup allocation that should be created
-    /// for that index.
-    std::unordered_map<std::string, std::vector<uint32_t>> workgroup_allocations;
-};
-
 /// @param module the Tint IR module to generate
 /// @returns the result of printing the MSL shader on success, or failure
-Result<PrintResult> Print(core::ir::Module& module, const Options& options);
+Result<Output> Print(core::ir::Module& module, const Options& options);
 
 }  // namespace tint::msl::writer
 
diff --git a/src/tint/lang/msl/writer/writer.cc b/src/tint/lang/msl/writer/writer.cc
index bd4dbc2..ad36b55 100644
--- a/src/tint/lang/msl/writer/writer.cc
+++ b/src/tint/lang/msl/writer/writer.cc
@@ -53,16 +53,13 @@
         return raise_result.Failure();
     }
 
-    // Generate the MSL code.
     auto result = Print(ir, options);
     if (result != Success) {
         return result.Failure();
     }
-    output.msl = result->msl;
-    output.workgroup_allocations = std::move(result->workgroup_allocations);
-    output.needs_storage_buffer_sizes = raise_result->needs_storage_buffer_sizes;
-    output.has_invariant_attribute = result->has_invariant_attribute;
-    return output;
+
+    result->needs_storage_buffer_sizes = raise_result->needs_storage_buffer_sizes;
+    return result;
 }
 
 Result<Output> Generate(const Program& program, const Options& options) {
@@ -93,7 +90,7 @@
     }
     output.msl = impl->Result();
     output.has_invariant_attribute = impl->HasInvariant();
-    output.workgroup_allocations = impl->DynamicWorkgroupAllocations();
+    output.workgroup_info.allocations = impl->DynamicWorkgroupAllocations();
 
     return output;
 }
diff --git a/src/tint/lang/msl/writer/writer.h b/src/tint/lang/msl/writer/writer.h
index e27c318..cdfdb2c 100644
--- a/src/tint/lang/msl/writer/writer.h
+++ b/src/tint/lang/msl/writer/writer.h
@@ -31,7 +31,7 @@
 #include <string>
 
 #include "src/tint/lang/msl/writer/common/options.h"
-#include "src/tint/lang/msl/writer/output.h"
+#include "src/tint/lang/msl/writer/common/output.h"
 #include "src/tint/utils/diagnostic/diagnostic.h"
 #include "src/tint/utils/result/result.h"
 
diff --git a/src/tint/lang/msl/writer/writer_test.cc b/src/tint/lang/msl/writer/writer_test.cc
index d11d0ef..f44b553 100644
--- a/src/tint/lang/msl/writer/writer_test.cc
+++ b/src/tint/lang/msl/writer/writer_test.cc
@@ -84,11 +84,11 @@
   foo_inner(tint_local_index, tint_module_vars);
 }
 )");
-    ASSERT_EQ(output_.workgroup_allocations.size(), 2u);
-    ASSERT_EQ(output_.workgroup_allocations.count("foo"), 1u);
-    ASSERT_EQ(output_.workgroup_allocations.count("bar"), 1u);
-    EXPECT_THAT(output_.workgroup_allocations.at("foo"), testing::ElementsAre(8u));
-    EXPECT_THAT(output_.workgroup_allocations.at("bar"), testing::ElementsAre());
+    ASSERT_EQ(output_.workgroup_info.allocations.size(), 2u);
+    ASSERT_EQ(output_.workgroup_info.allocations.count("foo"), 1u);
+    ASSERT_EQ(output_.workgroup_info.allocations.count("bar"), 1u);
+    EXPECT_THAT(output_.workgroup_info.allocations.at("foo"), testing::ElementsAre(8u));
+    EXPECT_THAT(output_.workgroup_info.allocations.at("bar"), testing::ElementsAre());
 }
 
 TEST_F(MslWriterTest, NeedsStorageBufferSizes_False) {