Update MSL Generator to return workgroup information.
This CL adds an output to the MSL generator to return the workgroup
size information after `SubstituteOverrides` has been run.
Bug: 380043961
Change-Id: I75dad66bed3afeabdc18756f75b1e109b2a5344b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/218856
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index bd5ead5..29d4380 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -375,7 +375,7 @@
)" + msl;
auto workgroupAllocations =
- std::move(result->workgroup_allocations.at(kRemappedEntryPointName));
+ std::move(result->workgroup_info.allocations.at(kRemappedEntryPointName));
return MslCompilation{{
std::move(msl),
std::move(kRemappedEntryPointName),
diff --git a/src/tint/lang/msl/writer/BUILD.bazel b/src/tint/lang/msl/writer/BUILD.bazel
index 766902e..489aadd 100644
--- a/src/tint/lang/msl/writer/BUILD.bazel
+++ b/src/tint/lang/msl/writer/BUILD.bazel
@@ -39,11 +39,9 @@
cc_library(
name = "writer",
srcs = [
- "output.cc",
"writer.cc",
],
hdrs = [
- "output.h",
"writer.h",
],
deps = [
diff --git a/src/tint/lang/msl/writer/BUILD.cmake b/src/tint/lang/msl/writer/BUILD.cmake
index fe540d4..536378a 100644
--- a/src/tint/lang/msl/writer/BUILD.cmake
+++ b/src/tint/lang/msl/writer/BUILD.cmake
@@ -48,8 +48,6 @@
# Condition: TINT_BUILD_MSL_WRITER
################################################################################
tint_add_target(tint_lang_msl_writer lib
- lang/msl/writer/output.cc
- lang/msl/writer/output.h
lang/msl/writer/writer.cc
lang/msl/writer/writer.h
)
diff --git a/src/tint/lang/msl/writer/BUILD.gn b/src/tint/lang/msl/writer/BUILD.gn
index 118d78a..9a2501bd 100644
--- a/src/tint/lang/msl/writer/BUILD.gn
+++ b/src/tint/lang/msl/writer/BUILD.gn
@@ -45,8 +45,6 @@
if (tint_build_msl_writer) {
libtint_source_set("writer") {
sources = [
- "output.cc",
- "output.h",
"writer.cc",
"writer.h",
]
diff --git a/src/tint/lang/msl/writer/common/BUILD.bazel b/src/tint/lang/msl/writer/common/BUILD.bazel
index 282ca55..89e8002 100644
--- a/src/tint/lang/msl/writer/common/BUILD.bazel
+++ b/src/tint/lang/msl/writer/common/BUILD.bazel
@@ -41,11 +41,13 @@
srcs = [
"option_helpers.cc",
"options.cc",
+ "output.cc",
"printer_support.cc",
],
hdrs = [
"option_helpers.h",
"options.h",
+ "output.h",
"printer_support.h",
],
deps = [
diff --git a/src/tint/lang/msl/writer/common/BUILD.cmake b/src/tint/lang/msl/writer/common/BUILD.cmake
index fd56a5e..bd173cc 100644
--- a/src/tint/lang/msl/writer/common/BUILD.cmake
+++ b/src/tint/lang/msl/writer/common/BUILD.cmake
@@ -45,6 +45,8 @@
lang/msl/writer/common/option_helpers.h
lang/msl/writer/common/options.cc
lang/msl/writer/common/options.h
+ lang/msl/writer/common/output.cc
+ lang/msl/writer/common/output.h
lang/msl/writer/common/printer_support.cc
lang/msl/writer/common/printer_support.h
)
diff --git a/src/tint/lang/msl/writer/common/BUILD.gn b/src/tint/lang/msl/writer/common/BUILD.gn
index 8567732..a1fb165 100644
--- a/src/tint/lang/msl/writer/common/BUILD.gn
+++ b/src/tint/lang/msl/writer/common/BUILD.gn
@@ -49,6 +49,8 @@
"option_helpers.h",
"options.cc",
"options.h",
+ "output.cc",
+ "output.h",
"printer_support.cc",
"printer_support.h",
]
diff --git a/src/tint/lang/msl/writer/output.cc b/src/tint/lang/msl/writer/common/output.cc
similarity index 96%
rename from src/tint/lang/msl/writer/output.cc
rename to src/tint/lang/msl/writer/common/output.cc
index cc2f3a0..14cf191 100644
--- a/src/tint/lang/msl/writer/output.cc
+++ b/src/tint/lang/msl/writer/common/output.cc
@@ -25,7 +25,7 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#include "src/tint/lang/msl/writer/output.h"
+#include "src/tint/lang/msl/writer/common/output.h"
namespace tint::msl::writer {
diff --git a/src/tint/lang/msl/writer/output.h b/src/tint/lang/msl/writer/common/output.h
similarity index 74%
rename from src/tint/lang/msl/writer/output.h
rename to src/tint/lang/msl/writer/common/output.h
index edcd2b2d..0d933df 100644
--- a/src/tint/lang/msl/writer/output.h
+++ b/src/tint/lang/msl/writer/common/output.h
@@ -25,13 +25,12 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#ifndef SRC_TINT_LANG_MSL_WRITER_OUTPUT_H_
-#define SRC_TINT_LANG_MSL_WRITER_OUTPUT_H_
+#ifndef SRC_TINT_LANG_MSL_WRITER_COMMON_OUTPUT_H_
+#define SRC_TINT_LANG_MSL_WRITER_COMMON_OUTPUT_H_
#include <cstdint>
#include <string>
#include <unordered_map>
-#include <unordered_set>
#include <vector>
namespace tint::msl::writer {
@@ -51,6 +50,21 @@
/// @returns this
Output& operator=(const Output&);
+ /// Workgroup size information
+ struct WorkgroupInfo {
+ /// The x-component
+ uint32_t x = 0;
+ /// The y-component
+ uint32_t y = 0;
+ /// The z-component
+ uint32_t z = 0;
+
+ /// A map from entry point name to a list of dynamic workgroup allocations.
+ /// Each entry in the vector is the size of the workgroup allocation that
+ /// should be created for that index.
+ std::unordered_map<std::string, std::vector<uint32_t>> allocations;
+ };
+
/// The generated MSL.
std::string msl = "";
@@ -60,12 +74,10 @@
/// True if the generated shader uses the invariant attribute.
bool has_invariant_attribute = false;
- /// A map from entry point name to a list of dynamic workgroup allocations.
- /// Each entry in the vector is the size of the workgroup allocation that
- /// should be created for that index.
- std::unordered_map<std::string, std::vector<uint32_t>> workgroup_allocations;
+ /// The workgroup size information, if the entry point was a compute shader
+ WorkgroupInfo workgroup_info{};
};
} // namespace tint::msl::writer
-#endif // SRC_TINT_LANG_MSL_WRITER_OUTPUT_H_
+#endif // SRC_TINT_LANG_MSL_WRITER_COMMON_OUTPUT_H_
diff --git a/src/tint/lang/msl/writer/function_test.cc b/src/tint/lang/msl/writer/function_test.cc
index d223271..96c6d9f 100644
--- a/src/tint/lang/msl/writer/function_test.cc
+++ b/src/tint/lang/msl/writer/function_test.cc
@@ -28,7 +28,8 @@
#include "src/tint/lang/core/type/sampled_texture.h"
#include "src/tint/lang/msl/writer/helper_test.h"
-using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
namespace tint::msl::writer {
namespace {
@@ -42,6 +43,28 @@
void foo() {
}
)");
+
+ // MSL doesn't inject an empty entry point, so in this case there is no result.
+ EXPECT_EQ(output_.workgroup_info.x, 0u);
+ EXPECT_EQ(output_.workgroup_info.y, 0u);
+ EXPECT_EQ(output_.workgroup_info.z, 0u);
+}
+
+TEST_F(MslWriterTest, Function_EntryPoint_Compute) {
+ auto* func = b.ComputeFunction("cmp_main", 32_u, 4_u, 1_u);
+ b.Append(func->Block(), [&] { //
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.msl;
+ EXPECT_EQ(output_.msl, MetalHeader() + R"(
+kernel void cmp_main() {
+}
+)");
+
+ EXPECT_EQ(output_.workgroup_info.x, 32u);
+ EXPECT_EQ(output_.workgroup_info.y, 4u);
+ EXPECT_EQ(output_.workgroup_info.z, 1u);
}
TEST_F(MslWriterTest, EntryPointParameterBufferBindingPoint) {
@@ -70,6 +93,9 @@
tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.storage_var=storage_var, .uniform_var=uniform_var};
}
)");
+ EXPECT_EQ(output_.workgroup_info.x, 0u);
+ EXPECT_EQ(output_.workgroup_info.y, 0u);
+ EXPECT_EQ(output_.workgroup_info.z, 0u);
}
TEST_F(MslWriterTest, EntryPointParameterHandleBindingPoint) {
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index ab1dd56..a842976 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -119,7 +119,7 @@
: ir_(module), options_(options) {}
/// @returns the generated MSL shader
- tint::Result<PrintResult> Generate() {
+ tint::Result<Output> Generate() {
auto valid = core::ir::ValidateAndDumpIfNeeded(
ir_, "msl.Printer",
core::ir::Capabilities{
@@ -157,7 +157,7 @@
private:
/// The result of printing the module.
- PrintResult result_;
+ Output result_;
/// Map of builtin structure to unique generated name
Hashmap<const core::type::Struct*, std::string, 4> builtin_struct_names_;
@@ -319,9 +319,20 @@
auto func_name = NameOf(func);
switch (func->Stage()) {
- case core::ir::Function::PipelineStage::kCompute:
+ case core::ir::Function::PipelineStage::kCompute: {
out << "kernel ";
+
+ auto const_wg_size = func->WorkgroupSizeAsConst();
+ TINT_ASSERT(const_wg_size);
+ auto wg_size = *const_wg_size;
+
+ // Store the workgroup information away to return from the generator.
+ result_.workgroup_info.x = wg_size[0];
+ result_.workgroup_info.y = wg_size[1];
+ result_.workgroup_info.z = wg_size[2];
+
break;
+ }
case core::ir::Function::PipelineStage::kFragment:
out << "fragment ";
break;
@@ -332,7 +343,7 @@
break;
}
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
- result_.workgroup_allocations.insert({func_name, {}});
+ result_.workgroup_info.allocations.insert({func_name, {}});
}
EmitType(out, func->ReturnType());
@@ -395,7 +406,7 @@
}
if (ptr && ptr->AddressSpace() == core::AddressSpace::kWorkgroup &&
func->Stage() == core::ir::Function::PipelineStage::kCompute) {
- auto& allocations = result_.workgroup_allocations.at(func_name);
+ auto& allocations = result_.workgroup_info.allocations.at(func_name);
out << " [[threadgroup(" << allocations.size() << ")]]";
allocations.push_back(ptr->StoreType()->Size());
}
@@ -1765,16 +1776,8 @@
} // namespace
-Result<PrintResult> Print(core::ir::Module& module, const Options& options) {
+Result<Output> Print(core::ir::Module& module, const Options& options) {
return Printer{module, options}.Generate();
}
-PrintResult::PrintResult() = default;
-
-PrintResult::~PrintResult() = default;
-
-PrintResult::PrintResult(const PrintResult&) = default;
-
-PrintResult& PrintResult::operator=(const PrintResult&) = default;
-
} // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.h b/src/tint/lang/msl/writer/printer/printer.h
index 18d49fd..4cbdbd5 100644
--- a/src/tint/lang/msl/writer/printer/printer.h
+++ b/src/tint/lang/msl/writer/printer/printer.h
@@ -33,6 +33,7 @@
#include <vector>
#include "src/tint/lang/msl/writer/common/options.h"
+#include "src/tint/lang/msl/writer/common/output.h"
#include "src/tint/utils/result/result.h"
// Forward declarations
@@ -42,36 +43,9 @@
namespace tint::msl::writer {
-/// The output produced when printing MSL.
-struct PrintResult {
- /// Constructor
- PrintResult();
-
- /// Destructor
- ~PrintResult();
-
- /// Copy constructor
- PrintResult(const PrintResult&);
-
- /// Copy assignment
- /// @returns this
- PrintResult& operator=(const PrintResult&);
-
- /// The generated MSL.
- std::string msl = "";
-
- /// `true` if an invariant attribute was generated.
- bool has_invariant_attribute = false;
-
- /// A map from entry point name to a list of dynamic workgroup allocations.
- /// Each element of the vector is the size of the workgroup allocation that should be created
- /// for that index.
- std::unordered_map<std::string, std::vector<uint32_t>> workgroup_allocations;
-};
-
/// @param module the Tint IR module to generate
/// @returns the result of printing the MSL shader on success, or failure
-Result<PrintResult> Print(core::ir::Module& module, const Options& options);
+Result<Output> Print(core::ir::Module& module, const Options& options);
} // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/writer.cc b/src/tint/lang/msl/writer/writer.cc
index bd4dbc2..ad36b55 100644
--- a/src/tint/lang/msl/writer/writer.cc
+++ b/src/tint/lang/msl/writer/writer.cc
@@ -53,16 +53,13 @@
return raise_result.Failure();
}
- // Generate the MSL code.
auto result = Print(ir, options);
if (result != Success) {
return result.Failure();
}
- output.msl = result->msl;
- output.workgroup_allocations = std::move(result->workgroup_allocations);
- output.needs_storage_buffer_sizes = raise_result->needs_storage_buffer_sizes;
- output.has_invariant_attribute = result->has_invariant_attribute;
- return output;
+
+ result->needs_storage_buffer_sizes = raise_result->needs_storage_buffer_sizes;
+ return result;
}
Result<Output> Generate(const Program& program, const Options& options) {
@@ -93,7 +90,7 @@
}
output.msl = impl->Result();
output.has_invariant_attribute = impl->HasInvariant();
- output.workgroup_allocations = impl->DynamicWorkgroupAllocations();
+ output.workgroup_info.allocations = impl->DynamicWorkgroupAllocations();
return output;
}
diff --git a/src/tint/lang/msl/writer/writer.h b/src/tint/lang/msl/writer/writer.h
index e27c318..cdfdb2c 100644
--- a/src/tint/lang/msl/writer/writer.h
+++ b/src/tint/lang/msl/writer/writer.h
@@ -31,7 +31,7 @@
#include <string>
#include "src/tint/lang/msl/writer/common/options.h"
-#include "src/tint/lang/msl/writer/output.h"
+#include "src/tint/lang/msl/writer/common/output.h"
#include "src/tint/utils/diagnostic/diagnostic.h"
#include "src/tint/utils/result/result.h"
diff --git a/src/tint/lang/msl/writer/writer_test.cc b/src/tint/lang/msl/writer/writer_test.cc
index d11d0ef..f44b553 100644
--- a/src/tint/lang/msl/writer/writer_test.cc
+++ b/src/tint/lang/msl/writer/writer_test.cc
@@ -84,11 +84,11 @@
foo_inner(tint_local_index, tint_module_vars);
}
)");
- ASSERT_EQ(output_.workgroup_allocations.size(), 2u);
- ASSERT_EQ(output_.workgroup_allocations.count("foo"), 1u);
- ASSERT_EQ(output_.workgroup_allocations.count("bar"), 1u);
- EXPECT_THAT(output_.workgroup_allocations.at("foo"), testing::ElementsAre(8u));
- EXPECT_THAT(output_.workgroup_allocations.at("bar"), testing::ElementsAre());
+ ASSERT_EQ(output_.workgroup_info.allocations.size(), 2u);
+ ASSERT_EQ(output_.workgroup_info.allocations.count("foo"), 1u);
+ ASSERT_EQ(output_.workgroup_info.allocations.count("bar"), 1u);
+ EXPECT_THAT(output_.workgroup_info.allocations.at("foo"), testing::ElementsAre(8u));
+ EXPECT_THAT(output_.workgroup_info.allocations.at("bar"), testing::ElementsAre());
}
TEST_F(MslWriterTest, NeedsStorageBufferSizes_False) {