[YUV AHB] Pass a struct to vk::ShaderModule::GetHandleAndSpirv Passing giant parameter lists was getting a bit out of hand. Bug: 468988322 Change-Id: I486d1fcbb2306262a5af7aeab826b7f1f799f4cf Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/297455 Reviewed-by: Brandon Jones <bajones@chromium.org> Commit-Queue: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Kyle Charbonneau <kylechar@google.com>
diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp index e3f27cc..2d0c20f 100644 --- a/src/dawn/native/vulkan/ComputePipelineVk.cpp +++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp
@@ -83,11 +83,11 @@ ShaderModule* module = ToBackend(computeStage.module.Get()); ShaderModule::ModuleAndSpirv moduleAndSpirv; - DAWN_TRY_ASSIGN( - moduleAndSpirv, - module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout, - /*emitPointSize=*/false, /*polyfillPixelCenter=*/false, - /*needsMultisampledFramebufferFetch=*/false, GetImmediateMask())); + DAWN_TRY_ASSIGN(moduleAndSpirv, module->GetHandleAndSpirv({ + .stage = &computeStage, + .layout = layout, + .immediateMask = GetImmediateMask(), + })); createInfo.stage.module = moduleAndSpirv.module; // string_view returned by GetIsolatedEntryPointName() points to a null-terminated string.
diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp index 9d25492..853202a 100644 --- a/src/dawn/native/vulkan/RenderPipelineVk.cpp +++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp
@@ -369,8 +369,6 @@ offsetof(RenderImmediateConstants, clampFragDepth), sizeof(ClampFragDepthArgs)); } - bool needsMultisampledFramebufferFetch = UseSampleRateShading() && UsesFramebufferFetch(); - // Initialize the layout after all modifications to mImmediateMask. DAWN_TRY_ASSIGN(mVkLayout, layout->GetOrCreateVkLayoutObject(mImmediateMask)); @@ -378,14 +376,10 @@ std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages; uint32_t stageCount = 0; - auto AddShaderStage = [&](SingleShaderStage stage, bool emitPointSize) -> MaybeError { - const ProgrammableStage& programmableStage = GetStage(stage); + auto AddShaderStage = [&](const ShaderModule::CompileParameters& compileParams) -> MaybeError { ShaderModule::ModuleAndSpirv moduleAndSpirv; - DAWN_TRY_ASSIGN(moduleAndSpirv, ToBackend(programmableStage.module) - ->GetHandleAndSpirv(stage, programmableStage, layout, - emitPointSize, polyfillPixelCenter, - needsMultisampledFramebufferFetch, - GetImmediateMask())); + DAWN_TRY_ASSIGN(moduleAndSpirv, + ToBackend(compileParams.stage->module)->GetHandleAndSpirv(compileParams)); mHasInputAttachment = mHasInputAttachment || moduleAndSpirv.hasInputAttachment; if (buildCacheKey) { // Record cache key for each shader since it will become inaccessible later on. @@ -398,7 +392,7 @@ shaderStage->pNext = nullptr; shaderStage->flags = 0; shaderStage->pSpecializationInfo = nullptr; - shaderStage->stage = VulkanShaderStage(stage); + shaderStage->stage = VulkanShaderStage(compileParams.stage->metadata->stage); // string_view returned by GetIsolatedEntryPointName() points to a null-terminated string. shaderStage->pName = device->GetIsolatedEntryPointName().data(); @@ -407,13 +401,23 @@ }; // Add the vertex stage that's always present. - DAWN_TRY(AddShaderStage(SingleShaderStage::Vertex, - GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList)); + DAWN_TRY(AddShaderStage({ + .stage = &GetStage(SingleShaderStage::Vertex), + .layout = layout, + .immediateMask = GetImmediateMask(), + .emitPointSize = GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList, + .polyfillPixelCenter = polyfillPixelCenter, + })); // Add the fragment stage if present. if (GetStageMask() & wgpu::ShaderStage::Fragment) { - DAWN_TRY(AddShaderStage(SingleShaderStage::Fragment, - /*emitPointSize*/ false)); + DAWN_TRY(AddShaderStage({ + .stage = &GetStage(SingleShaderStage::Fragment), + .layout = layout, + .immediateMask = GetImmediateMask(), + .polyfillPixelCenter = polyfillPixelCenter, + .needsMultisampledFramebufferFetch = UseSampleRateShading() && UsesFramebufferFetch(), + })); } PipelineVertexInputStateCreateInfoTemporaryAllocations tempAllocations;
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp index 51ae3da..bfc251e 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -127,13 +127,7 @@ #endif // TINT_BUILD_SPV_WRITER ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv( - SingleShaderStage stage, - const ProgrammableStage& programmableStage, - const PipelineLayout* layout, - bool emitPointSize, - bool polyfillPixelCenter, - bool needsMultisampledFramebufferFetch, - const ImmediateConstantMask& pipelineImmediateMask) { + const CompileParameters& in) { TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv"); #if TINT_BUILD_SPV_WRITER @@ -143,7 +137,7 @@ // bindings for all other bindgroups by 1. BindGroupIndex startOfBindGroups{0}; std::optional<tint::ResourceTableConfig> resourceTableConfig = std::nullopt; - if (layout->UsesResourceTable()) { + if (in.layout->UsesResourceTable()) { startOfBindGroups = BindGroupIndex(1); auto bindingTypeOrder = ResourceTableDefaultResources::GetOrder(); @@ -154,8 +148,8 @@ }; } - tint::Bindings bindings = - GenerateBindingRemapping(layout, stage, [&](BindGroupIndex group, BindingIndex index) { + tint::Bindings bindings = GenerateBindingRemapping( + in.layout, in.stage->metadata->stage, [&](BindGroupIndex group, BindingIndex index) { return tint::BindingPoint{ .group = uint32_t(startOfBindGroups + group), .binding = uint32_t(index), @@ -165,8 +159,8 @@ // Post process the binding remapping to make statically paired texture point at the sampler // binding point instead. std::unordered_set<tint::BindingPoint> staticallyPairedTextureBindingPoints; - for (BindGroupIndex group : layout->GetBindGroupLayoutsMask()) { - const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); + for (BindGroupIndex group : in.layout->GetBindGroupLayoutsMask()) { + const BindGroupLayout* bgl = ToBackend(in.layout->GetBindGroupLayout(group)); for (BindingIndex index : bgl->GetSampledTextureIndices()) { const auto& bindingInfo = bgl->GetBindingInfo(index); @@ -183,25 +177,25 @@ const bool hasInputAttachment = !bindings.input_attachment.empty(); SpirvCompilationRequest req = {}; - req.stage = stage; + req.stage = in.stage->metadata->stage; req.shaderModuleHash = GetHash(); req.inputProgram = UnsafeUnserializedValue(UseTintProgram()); req.platform = UnsafeUnserializedValue(GetDevice()->GetPlatform()); - req.usesSubgroupMatrix = programmableStage.metadata->usesSubgroupMatrix; + req.usesSubgroupMatrix = in.stage->metadata->usesSubgroupMatrix; // TODO(464008240): Cleanup the exposing of `EnumerateSubgroupMatrixConfigs` when possible. req.subgroupMatrixConfig = ToBackend(GetDevice()->GetPhysicalDevice()) ->EnumerateSubgroupMatrixConfigs(GetDevice()->GetAdapter()->GetTogglesState()); - req.tintOptions.entry_point_name = programmableStage.entryPoint; + req.tintOptions.entry_point_name = in.stage->entryPoint; req.tintOptions.remapped_entry_point_name = GetDevice()->GetIsolatedEntryPointName(); req.tintOptions.strip_all_names = !GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming); req.tintOptions.statically_paired_texture_binding_points = std::move(staticallyPairedTextureBindingPoints); req.tintOptions.substitute_overrides_config = { - .map = BuildSubstituteOverridesTransformConfig(programmableStage), + .map = BuildSubstituteOverridesTransformConfig(*in.stage), }; req.tintOptions.bindings = std::move(bindings); req.tintOptions.resource_table = std::move(resourceTableConfig); @@ -217,9 +211,9 @@ req.tintOptions.disable_polyfill_integer_div_mod = GetDevice()->IsToggleEnabled(Toggle::DisablePolyfillsOnIntegerDivisonAndModulo); - req.tintOptions.emit_vertex_point_size = emitPointSize; - req.tintOptions.polyfill_pixel_center = polyfillPixelCenter; - req.tintOptions.multisampled_framebuffer_fetch = needsMultisampledFramebufferFetch; + req.tintOptions.emit_vertex_point_size = in.emitPointSize; + req.tintOptions.polyfill_pixel_center = in.polyfillPixelCenter; + req.tintOptions.multisampled_framebuffer_fetch = in.needsMultisampledFramebufferFetch; req.tintOptions.spirv_version = GetDevice()->IsToggleEnabled(Toggle::UseSpirv14) ? tint::spirv::writer::SpvVersion::kSpv14 @@ -276,9 +270,9 @@ } // Set internal immediate constant offsets - if (HasImmediateConstants(&RenderImmediateConstants::clampFragDepth, pipelineImmediateMask)) { + if (HasImmediateConstants(&RenderImmediateConstants::clampFragDepth, in.immediateMask)) { uint32_t offsetStartBytes = GetImmediateByteOffsetInPipeline( - &RenderImmediateConstants::clampFragDepth, pipelineImmediateMask); + &RenderImmediateConstants::clampFragDepth, in.immediateMask); req.tintOptions.depth_range_offsets = { offsetStartBytes, offsetStartBytes + kImmediateConstantElementByteSize}; }
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h index fd713f8..d18b5d9 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.h +++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -52,6 +52,12 @@ class ShaderModule final : public ShaderModuleBase { public: + static ResultOrError<Ref<ShaderModule>> Create( + Device* device, + const UnpackedPtr<ShaderModuleDescriptor>& descriptor, + const std::vector<tint::wgsl::Extension>& internalExtensions); + + // Caller is responsible for destroying the `VkShaderModule` returned. struct ModuleAndSpirv { VkShaderModule module; std::vector<uint32_t> spirv; @@ -59,20 +65,18 @@ Extent3D workgroupSize; std::optional<uint32_t> explicitSubgroupSize; }; + struct CompileParameters { + // Kept without defaults as they must be provided. + raw_ptr<const ProgrammableStage> stage; + raw_ptr<const PipelineLayout> layout; + ImmediateConstantMask immediateMask; - static ResultOrError<Ref<ShaderModule>> Create( - Device* device, - const UnpackedPtr<ShaderModuleDescriptor>& descriptor, - const std::vector<tint::wgsl::Extension>& internalExtensions); + bool emitPointSize = false; + bool polyfillPixelCenter = false; + bool needsMultisampledFramebufferFetch = false; + }; - // Caller is responsible for destroying the `VkShaderModule` returned. - ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(SingleShaderStage stage, - const ProgrammableStage& programmableStage, - const PipelineLayout* layout, - bool emitPointSize, - bool polyfillPixelCenter, - bool needsMultisampledFramebufferFetch, - const ImmediateConstantMask& pipelineMask); + ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const CompileParameters& p); private: ShaderModule(Device* device,