[msl] Move VertexPulling into the backend
VertexPulling must run before BindingRemapper and Robustness.
Fixed: 380044409
Change-Id: I6ef4b37780323ceee4976e20997f39eee8fbde1e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/222021
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn
index 10c31cc..a418a8f 100644
--- a/src/dawn/native/BUILD.gn
+++ b/src/dawn/native/BUILD.gn
@@ -677,8 +677,6 @@
"metal/UtilsMetal.mm",
]
- deps += [ "${dawn_root}/src/tint/lang/core/ir/transform" ]
-
# If a "build with ARC" config is present, remove it.
if (filter_include(configs, [ "//build/config/compiler:enable_arc" ]) !=
[]) {
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index 940e39f..6b78950 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -44,7 +44,6 @@
#include "dawn/platform/DawnPlatform.h"
#include "dawn/platform/metrics/HistogramMacros.h"
#include "dawn/platform/tracing/TraceEvent.h"
-#include "src/tint/lang/core/ir/transform/vertex_pulling.h"
#include <tint/tint.h>
@@ -61,7 +60,6 @@
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
X(SingleShaderStage, stage) \
X(const tint::Program*, inputProgram) \
- X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \
X(std::optional<tint::ast::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
X(LimitsForCompilationRequest, limits) \
X(CacheKey::UnsafeUnkeyedValue<const AdapterBase*>, adapter) \
@@ -274,7 +272,6 @@
req.stage = stage;
auto tintProgram = programmableStage.module->GetTintProgram();
req.inputProgram = &(tintProgram->program);
- req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig);
req.substituteOverrideConfig = std::move(substituteOverrideConfig);
req.entryPointName = programmableStage.entryPoint.c_str();
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
@@ -299,6 +296,7 @@
req.tintOptions.bindings = std::move(bindings);
req.tintOptions.disable_polyfill_integer_div_mod =
device->IsToggleEnabled(Toggle::DisablePolyfillsOnIntegerDivisonAndModulo);
+ req.tintOptions.vertex_pulling_config = std::move(vertexPullingTransformConfig);
const CombinedLimits& limits = device->GetLimits();
req.limits = LimitsForCompilationRequest::Create(limits.v1);
@@ -326,15 +324,6 @@
r.disableSymbolRenaming ? tint::ast::transform::Renamer::Target::kMslKeywords
: tint::ast::transform::Renamer::Target::kAll,
std::move(requestedNames));
-
- if (r.vertexPullingTransformConfig) {
- tint::ast::transform::VertexPulling::Config config;
- config.pulling_group = r.vertexPullingTransformConfig->pulling_group;
- config.vertex_state = r.vertexPullingTransformConfig->vertex_state;
- transformManager.Add<tint::ast::transform::VertexPulling>();
- transformInputs.Add<tint::ast::transform::VertexPulling::Config>(
- std::move(config));
- }
}
if (r.substituteOverrideConfig) {
@@ -365,15 +354,6 @@
"An error occurred while generating Tint IR\n%s",
ir.Failure().reason.Str());
- // TODO(380044409): Move this into the backend.
- if (r.vertexPullingTransformConfig) {
- auto vertex_pulling_result = tint::core::ir::transform::VertexPulling(
- ir.Get(), *r.vertexPullingTransformConfig);
- DAWN_INVALID_IF(vertex_pulling_result != tint::Success,
- "An error occurred while running vertex pulling:\n%s",
- vertex_pulling_result.Failure().reason.Str());
- }
-
result = tint::msl::writer::Generate(ir.Get(), r.tintOptions);
// Workgroup validation has to come after `Generate` because it may require
diff --git a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
index 7236d60..39c1e69 100644
--- a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
@@ -85,6 +85,7 @@
#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
#include "src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.h"
+#include "src/tint/lang/wgsl/ast/transform/vertex_pulling.h"
#include "src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h"
#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/lang/wgsl/helpers/check_supported_extensions.h"
@@ -162,6 +163,15 @@
manager.Add<ast::transform::PromoteSideEffectsToDecl>();
+ // VertexPulling must come before Robustness.
+ if (options.vertex_pulling_config) {
+ ast::transform::VertexPulling::Config config;
+ config.pulling_group = options.vertex_pulling_config->pulling_group;
+ config.vertex_state = options.vertex_pulling_config->vertex_state;
+ manager.Add<ast::transform::VertexPulling>();
+ data.Add<ast::transform::VertexPulling::Config>(std::move(config));
+ }
+
if (!options.disable_robustness) {
// Robustness must come after PromoteSideEffectsToDecl
// Robustness must come before BuiltinPolyfill and CanonicalizeEntryPointIO
diff --git a/src/tint/lang/msl/writer/common/options.h b/src/tint/lang/msl/writer/common/options.h
index bd13057..5bc74c2 100644
--- a/src/tint/lang/msl/writer/common/options.h
+++ b/src/tint/lang/msl/writer/common/options.h
@@ -33,6 +33,7 @@
#include <unordered_map>
#include "src/tint/api/common/binding_point.h"
+#include "src/tint/api/common/vertex_pulling_config.h"
#include "src/tint/utils/reflection.h"
namespace tint::msl::writer {
@@ -174,7 +175,10 @@
/// from which to load buffer sizes.
ArrayLengthFromUniformOptions array_length_from_uniform = {};
- /// The bindings
+ /// The optional vertex pulling configuration.
+ std::optional<VertexPullingConfig> vertex_pulling_config = {};
+
+ /// The bindings.
Bindings bindings;
/// Reflect the fields of this class so that it can be used by tint::ForeachField()
@@ -190,6 +194,7 @@
fixed_sample_mask,
pixel_local_attachments,
array_length_from_uniform,
+ vertex_pulling_config,
bindings);
};
diff --git a/src/tint/lang/msl/writer/raise/raise.cc b/src/tint/lang/msl/writer/raise/raise.cc
index 7a8533f..3e94640 100644
--- a/src/tint/lang/msl/writer/raise/raise.cc
+++ b/src/tint/lang/msl/writer/raise/raise.cc
@@ -45,6 +45,7 @@
#include "src/tint/lang/core/ir/transform/robustness.h"
#include "src/tint/lang/core/ir/transform/value_to_let.h"
#include "src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.h"
+#include "src/tint/lang/core/ir/transform/vertex_pulling.h"
#include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"
#include "src/tint/lang/msl/writer/common/option_helpers.h"
#include "src/tint/lang/msl/writer/raise/binary_polyfill.h"
@@ -68,6 +69,11 @@
RaiseResult raise_result;
+ // VertexPulling must come before BindingRemapper and Robustness.
+ if (options.vertex_pulling_config) {
+ RUN_TRANSFORM(core::ir::transform::VertexPulling, module, *options.vertex_pulling_config);
+ }
+
tint::transform::multiplanar::BindingsMap multiplanar_map{};
RemapperData remapper_data{};
ArrayLengthFromUniformOptions array_length_from_uniform_options{};
diff --git a/src/tint/lang/msl/writer/writer_test.cc b/src/tint/lang/msl/writer/writer_test.cc
index 7826186..1cb9e27 100644
--- a/src/tint/lang/msl/writer/writer_test.cc
+++ b/src/tint/lang/msl/writer/writer_test.cc
@@ -231,5 +231,65 @@
)");
}
+TEST_F(MslWriterTest, VertexPulling) {  // MSL output with Options::vertex_pulling_config set.
+ auto* ep = b.Function("main", ty.vec4<f32>(), core::ir::Function::PipelineStage::kVertex);
+ ep->SetReturnBuiltin(core::BuiltinValue::kPosition);
+ auto* attr = b.FunctionParam<vec4<f32>>("attr");
+ attr->SetLocation(1);
+ ep->SetParams({attr});
+ b.Append(ep->Block(), [&] {  // Return the location(1) attribute as the position.
+ b.Return(ep, attr);
+ });
+
+ VertexPullingConfig vertex_pulling_config;
+ vertex_pulling_config.pulling_group = 4u;  // Same group as the storage binding remapped below.
+ vertex_pulling_config.vertex_state = {
+ {{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 1}}}}};
+ ArrayLengthFromUniformOptions array_length_config;
+ array_length_config.ubo_binding = 30u;  // Shows up as [[buffer(30)]] in the expected MSL.
+ array_length_config.bindpoint_to_size_index.insert({BindingPoint{0u, 1u}, 0u});
+ Options options;
+ options.bindings.storage.emplace(BindingPoint{4u, 0u}, tint::msl::writer::binding::Storage{1u});
+ options.vertex_pulling_config = std::move(vertex_pulling_config);
+ options.array_length_from_uniform = std::move(array_length_config);
+
+ ASSERT_TRUE(Generate(options)) << err_ << output_.msl;
+ EXPECT_EQ(output_.msl, R"(#include <metal_stdlib>
+using namespace metal;
+
+template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
+
+struct tint_module_vars_struct {
+ const device tint_array<uint, 1>* tint_vertex_buffer_0;
+ const constant tint_array<uint4, 1>* tint_storage_buffer_sizes;
+};
+
+struct main_outputs {
+ float4 tint_symbol [[position]];
+};
+
+float4 main_inner(uint tint_vertex_index, tint_module_vars_struct tint_module_vars) {
+ return float4(as_type<float>((*tint_module_vars.tint_vertex_buffer_0)[min(tint_vertex_index, (((*tint_module_vars.tint_storage_buffer_sizes)[0u].x / 4u) - 1u))]), 0.0f, 0.0f, 1.0f);
+}
+
+vertex main_outputs v(uint tint_vertex_index [[vertex_id]], const device tint_array<uint, 1>* tint_vertex_buffer_0 [[buffer(1)]], const constant tint_array<uint4, 1>* tint_storage_buffer_sizes [[buffer(30)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.tint_vertex_buffer_0=tint_vertex_buffer_0, .tint_storage_buffer_sizes=tint_storage_buffer_sizes};
+ main_outputs tint_wrapper_result = {};
+ tint_wrapper_result.tint_symbol = main_inner(tint_vertex_index, tint_module_vars);
+ return tint_wrapper_result;
+}
+)");
+}
+
} // namespace
} // namespace tint::msl::writer