[YUV AHB] Pass a struct to vk::ShaderModule::GetHandleAndSpirv

Passing giant parameter lists was getting a bit out of hand.

Bug: 468988322
Change-Id: I486d1fcbb2306262a5af7aeab826b7f1f799f4cf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/297455
Reviewed-by: Brandon Jones <bajones@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kyle Charbonneau <kylechar@google.com>
diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp
index e3f27cc..2d0c20f 100644
--- a/src/dawn/native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp
@@ -83,11 +83,11 @@
     ShaderModule* module = ToBackend(computeStage.module.Get());
 
     ShaderModule::ModuleAndSpirv moduleAndSpirv;
-    DAWN_TRY_ASSIGN(
-        moduleAndSpirv,
-        module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout,
-                                  /*emitPointSize=*/false, /*polyfillPixelCenter=*/false,
-                                  /*needsMultisampledFramebufferFetch=*/false, GetImmediateMask()));
+    DAWN_TRY_ASSIGN(moduleAndSpirv, module->GetHandleAndSpirv({
+                                        .stage = &computeStage,
+                                        .layout = layout,
+                                        .immediateMask = GetImmediateMask(),
+                                    }));
 
     createInfo.stage.module = moduleAndSpirv.module;
     // string_view returned by GetIsolatedEntryPointName() points to a null-terminated string.
diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp
index 9d25492..853202a 100644
--- a/src/dawn/native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp
@@ -369,8 +369,6 @@
             offsetof(RenderImmediateConstants, clampFragDepth), sizeof(ClampFragDepthArgs));
     }
 
-    bool needsMultisampledFramebufferFetch = UseSampleRateShading() && UsesFramebufferFetch();
-
     // Initialize the layout after all modifications to mImmediateMask.
     DAWN_TRY_ASSIGN(mVkLayout, layout->GetOrCreateVkLayoutObject(mImmediateMask));
 
@@ -378,14 +376,10 @@
     std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
     uint32_t stageCount = 0;
 
-    auto AddShaderStage = [&](SingleShaderStage stage, bool emitPointSize) -> MaybeError {
-        const ProgrammableStage& programmableStage = GetStage(stage);
+    auto AddShaderStage = [&](const ShaderModule::CompileParameters& compileParams) -> MaybeError {
         ShaderModule::ModuleAndSpirv moduleAndSpirv;
-        DAWN_TRY_ASSIGN(moduleAndSpirv, ToBackend(programmableStage.module)
-                                            ->GetHandleAndSpirv(stage, programmableStage, layout,
-                                                                emitPointSize, polyfillPixelCenter,
-                                                                needsMultisampledFramebufferFetch,
-                                                                GetImmediateMask()));
+        DAWN_TRY_ASSIGN(moduleAndSpirv,
+                        ToBackend(compileParams.stage->module)->GetHandleAndSpirv(compileParams));
         mHasInputAttachment = mHasInputAttachment || moduleAndSpirv.hasInputAttachment;
         if (buildCacheKey) {
             // Record cache key for each shader since it will become inaccessible later on.
@@ -398,7 +392,7 @@
         shaderStage->pNext = nullptr;
         shaderStage->flags = 0;
         shaderStage->pSpecializationInfo = nullptr;
-        shaderStage->stage = VulkanShaderStage(stage);
+        shaderStage->stage = VulkanShaderStage(compileParams.stage->metadata->stage);
         // string_view returned by GetIsolatedEntryPointName() points to a null-terminated string.
         shaderStage->pName = device->GetIsolatedEntryPointName().data();
 
@@ -407,13 +401,23 @@
     };
 
     // Add the vertex stage that's always present.
-    DAWN_TRY(AddShaderStage(SingleShaderStage::Vertex,
-                            GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList));
+    DAWN_TRY(AddShaderStage({
+        .stage = &GetStage(SingleShaderStage::Vertex),
+        .layout = layout,
+        .immediateMask = GetImmediateMask(),
+        .emitPointSize = GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList,
+        .polyfillPixelCenter = polyfillPixelCenter,
+    }));
 
     // Add the fragment stage if present.
     if (GetStageMask() & wgpu::ShaderStage::Fragment) {
-        DAWN_TRY(AddShaderStage(SingleShaderStage::Fragment,
-                                /*emitPointSize*/ false));
+        DAWN_TRY(AddShaderStage({
+            .stage = &GetStage(SingleShaderStage::Fragment),
+            .layout = layout,
+            .immediateMask = GetImmediateMask(),
+            .polyfillPixelCenter = polyfillPixelCenter,
+            .needsMultisampledFramebufferFetch = UseSampleRateShading() && UsesFramebufferFetch(),
+        }));
     }
 
     PipelineVertexInputStateCreateInfoTemporaryAllocations tempAllocations;
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 51ae3da..bfc251e 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -127,13 +127,7 @@
 #endif  // TINT_BUILD_SPV_WRITER
 
 ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
-    SingleShaderStage stage,
-    const ProgrammableStage& programmableStage,
-    const PipelineLayout* layout,
-    bool emitPointSize,
-    bool polyfillPixelCenter,
-    bool needsMultisampledFramebufferFetch,
-    const ImmediateConstantMask& pipelineImmediateMask) {
+    const CompileParameters& in) {
     TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
 
 #if TINT_BUILD_SPV_WRITER
@@ -143,7 +137,7 @@
     // bindings for all other bindgroups by 1.
     BindGroupIndex startOfBindGroups{0};
     std::optional<tint::ResourceTableConfig> resourceTableConfig = std::nullopt;
-    if (layout->UsesResourceTable()) {
+    if (in.layout->UsesResourceTable()) {
         startOfBindGroups = BindGroupIndex(1);
 
         auto bindingTypeOrder = ResourceTableDefaultResources::GetOrder();
@@ -154,8 +148,8 @@
         };
     }
 
-    tint::Bindings bindings =
-        GenerateBindingRemapping(layout, stage, [&](BindGroupIndex group, BindingIndex index) {
+    tint::Bindings bindings = GenerateBindingRemapping(
+        in.layout, in.stage->metadata->stage, [&](BindGroupIndex group, BindingIndex index) {
             return tint::BindingPoint{
                 .group = uint32_t(startOfBindGroups + group),
                 .binding = uint32_t(index),
@@ -165,8 +159,8 @@
     // Post process the binding remapping to make statically paired texture point at the sampler
     // binding point instead.
     std::unordered_set<tint::BindingPoint> staticallyPairedTextureBindingPoints;
-    for (BindGroupIndex group : layout->GetBindGroupLayoutsMask()) {
-        const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
+    for (BindGroupIndex group : in.layout->GetBindGroupLayoutsMask()) {
+        const BindGroupLayout* bgl = ToBackend(in.layout->GetBindGroupLayout(group));
 
         for (BindingIndex index : bgl->GetSampledTextureIndices()) {
             const auto& bindingInfo = bgl->GetBindingInfo(index);
@@ -183,25 +177,25 @@
     const bool hasInputAttachment = !bindings.input_attachment.empty();
 
     SpirvCompilationRequest req = {};
-    req.stage = stage;
+    req.stage = in.stage->metadata->stage;
     req.shaderModuleHash = GetHash();
     req.inputProgram = UnsafeUnserializedValue(UseTintProgram());
     req.platform = UnsafeUnserializedValue(GetDevice()->GetPlatform());
-    req.usesSubgroupMatrix = programmableStage.metadata->usesSubgroupMatrix;
+    req.usesSubgroupMatrix = in.stage->metadata->usesSubgroupMatrix;
 
     // TODO(464008240): Cleanup the exposing of `EnumerateSubgroupMatrixConfigs` when possible.
     req.subgroupMatrixConfig =
         ToBackend(GetDevice()->GetPhysicalDevice())
             ->EnumerateSubgroupMatrixConfigs(GetDevice()->GetAdapter()->GetTogglesState());
 
-    req.tintOptions.entry_point_name = programmableStage.entryPoint;
+    req.tintOptions.entry_point_name = in.stage->entryPoint;
     req.tintOptions.remapped_entry_point_name = GetDevice()->GetIsolatedEntryPointName();
     req.tintOptions.strip_all_names = !GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming);
 
     req.tintOptions.statically_paired_texture_binding_points =
         std::move(staticallyPairedTextureBindingPoints);
     req.tintOptions.substitute_overrides_config = {
-        .map = BuildSubstituteOverridesTransformConfig(programmableStage),
+        .map = BuildSubstituteOverridesTransformConfig(*in.stage),
     };
     req.tintOptions.bindings = std::move(bindings);
     req.tintOptions.resource_table = std::move(resourceTableConfig);
@@ -217,9 +211,9 @@
     req.tintOptions.disable_polyfill_integer_div_mod =
         GetDevice()->IsToggleEnabled(Toggle::DisablePolyfillsOnIntegerDivisonAndModulo);
 
-    req.tintOptions.emit_vertex_point_size = emitPointSize;
-    req.tintOptions.polyfill_pixel_center = polyfillPixelCenter;
-    req.tintOptions.multisampled_framebuffer_fetch = needsMultisampledFramebufferFetch;
+    req.tintOptions.emit_vertex_point_size = in.emitPointSize;
+    req.tintOptions.polyfill_pixel_center = in.polyfillPixelCenter;
+    req.tintOptions.multisampled_framebuffer_fetch = in.needsMultisampledFramebufferFetch;
 
     req.tintOptions.spirv_version = GetDevice()->IsToggleEnabled(Toggle::UseSpirv14)
                                         ? tint::spirv::writer::SpvVersion::kSpv14
@@ -276,9 +270,9 @@
     }
 
     // Set internal immediate constant offsets
-    if (HasImmediateConstants(&RenderImmediateConstants::clampFragDepth, pipelineImmediateMask)) {
+    if (HasImmediateConstants(&RenderImmediateConstants::clampFragDepth, in.immediateMask)) {
         uint32_t offsetStartBytes = GetImmediateByteOffsetInPipeline(
-            &RenderImmediateConstants::clampFragDepth, pipelineImmediateMask);
+            &RenderImmediateConstants::clampFragDepth, in.immediateMask);
         req.tintOptions.depth_range_offsets = {
             offsetStartBytes, offsetStartBytes + kImmediateConstantElementByteSize};
     }
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index fd713f8..d18b5d9 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -52,6 +52,12 @@
 
 class ShaderModule final : public ShaderModuleBase {
   public:
+    static ResultOrError<Ref<ShaderModule>> Create(
+        Device* device,
+        const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+        const std::vector<tint::wgsl::Extension>& internalExtensions);
+
+    // Caller is responsible for destroying the `VkShaderModule` returned.
     struct ModuleAndSpirv {
         VkShaderModule module;
         std::vector<uint32_t> spirv;
@@ -59,20 +65,18 @@
         Extent3D workgroupSize;
         std::optional<uint32_t> explicitSubgroupSize;
     };
+    struct CompileParameters {
+        // Kept without defaults as they must be provided.
+        raw_ptr<const ProgrammableStage> stage;
+        raw_ptr<const PipelineLayout> layout;
+        ImmediateConstantMask immediateMask;
 
-    static ResultOrError<Ref<ShaderModule>> Create(
-        Device* device,
-        const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions);
+        bool emitPointSize = false;
+        bool polyfillPixelCenter = false;
+        bool needsMultisampledFramebufferFetch = false;
+    };
 
-    // Caller is responsible for destroying the `VkShaderModule` returned.
-    ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(SingleShaderStage stage,
-                                                    const ProgrammableStage& programmableStage,
-                                                    const PipelineLayout* layout,
-                                                    bool emitPointSize,
-                                                    bool polyfillPixelCenter,
-                                                    bool needsMultisampledFramebufferFetch,
-                                                    const ImmediateConstantMask& pipelineMask);
+    ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const CompileParameters& p);
 
   private:
     ShaderModule(Device* device,