[Static samplers] Make BindingInfo variant This CL adds a BindingInfo variant for static samplers and does an initial pass at handling the variant everywhere that it needs to be handled. Backend implementations are stubbed out. Change-Id: If2c16cbcedce64fd800d7163b3a0eb804d61033c Bug: dawn:2643 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/181360 Reviewed-by: Austin Eng <enga@chromium.org> Reviewed-by: Nicolette Prevost <nicolettep@google.com> Reviewed-by: Corentin Wallez <cwallez@chromium.org> Commit-Queue: Colin Blundell <blundell@chromium.org>
diff --git a/src/dawn/native/BindGroup.cpp b/src/dawn/native/BindGroup.cpp index 4e2a217..65ff312 100644 --- a/src/dawn/native/BindGroup.cpp +++ b/src/dawn/native/BindGroup.cpp
@@ -305,11 +305,17 @@ DAWN_TRY(device->ValidateObject(descriptor->layout)); BindGroupLayoutInternalBase* layout = descriptor->layout->GetInternalBindGroupLayout(); + + // NOTE: Static sampler layout bindings should not have bind group entries, + // as the sampler is specified in the layout itself. + const auto expectedBindingsCount = + layout->GetUnexpandedBindingCount() - layout->GetStaticSamplerCount(); + DAWN_INVALID_IF( - descriptor->entryCount != layout->GetUnexpandedBindingCount(), - "Number of entries (%u) did not match the number of entries (%u) specified in %s." + descriptor->entryCount != expectedBindingsCount, + "Number of entries (%u) did not match the expected number of entries (%u) for %s." "\nExpected layout: %s", - descriptor->entryCount, static_cast<uint32_t>(layout->GetBindingCount()), layout, + descriptor->entryCount, static_cast<uint32_t>(expectedBindingsCount), layout, layout->EntriesToString()); const BindGroupLayoutInternalBase::BindingMap& bindingMap = layout->GetBindingMap(); @@ -389,6 +395,12 @@ "\nExpected entry layout: %s", i, layout); return {}; + }, + [&](const StaticSamplerHolderBindingLayout& layout) -> MaybeError { + return DAWN_VALIDATION_ERROR( + "entries[%u] is provided when the layout contains a static sampler for that " + "binding.", + i); })); } @@ -397,7 +409,7 @@ // - Each binding must be set at most once // // We don't validate the equality because it wouldn't be possible to cover it with a test. - DAWN_ASSERT(bindingsSet.count() == layout->GetUnexpandedBindingCount()); + DAWN_ASSERT(bindingsSet.count() == expectedBindingsCount); return {}; }
diff --git a/src/dawn/native/BindGroupLayoutInternal.cpp b/src/dawn/native/BindGroupLayoutInternal.cpp index 0755617..d20bc05 100644 --- a/src/dawn/native/BindGroupLayoutInternal.cpp +++ b/src/dawn/native/BindGroupLayoutInternal.cpp
@@ -356,6 +356,11 @@ const SamplerBindingLayout& layoutB = std::get<SamplerBindingLayout>(b.bindingLayout); return layoutA.type != layoutB.type; }, + [&](const StaticSamplerHolderBindingLayout& layoutA) -> bool { + const StaticSamplerHolderBindingLayout& layoutB = + std::get<StaticSamplerHolderBindingLayout>(b.bindingLayout); + return layoutA.sampler != layoutB.sampler; + }, [&](const TextureBindingLayout& layoutA) -> bool { const TextureBindingLayout& layoutB = std::get<TextureBindingLayout>(b.bindingLayout); return layoutA.sampleType != layoutB.sampleType || @@ -405,7 +410,9 @@ } bindingInfo.bindingLayout = bindingLayout; } else if (auto* staticSamplerBindingLayout = binding.Get<StaticSamplerBindingLayout>()) { - // TODO(crbug.com/dawn/2463): Populate BindingInfo for this entry. + StaticSamplerHolderBindingLayout bindingLayout; + bindingLayout.sampler = staticSamplerBindingLayout->sampler; + bindingInfo.bindingLayout = bindingLayout; } else { DAWN_UNREACHABLE(); } @@ -504,6 +511,14 @@ } break; } + case BindingInfoType::StaticSampler: { + const auto& aLayout = std::get<StaticSamplerHolderBindingLayout>(aInfo.bindingLayout); + const auto& bLayout = std::get<StaticSamplerHolderBindingLayout>(bInfo.bindingLayout); + if (aLayout.sampler != bLayout.sampler) { + return aLayout.sampler < bLayout.sampler; + } + break; + } case BindingInfoType::ExternalTexture: DAWN_UNREACHABLE(); break; @@ -632,6 +647,9 @@ [&](const StorageTextureBindingLayout& layout) { recorder.Record(BindingInfoType::StorageTexture, layout.access, layout.format, layout.viewDimension); + }, + [&](const StaticSamplerHolderBindingLayout& layout) { + recorder.Record(BindingInfoType::StaticSampler, layout.sampler->GetContentHash()); }); } @@ -663,6 +681,10 @@ return mBindingCounts.unverifiedBufferCount; } +uint32_t BindGroupLayoutInternalBase::GetStaticSamplerCount() const { + return mBindingCounts.staticSamplerCount; +} + uint32_t BindGroupLayoutInternalBase::GetExternalTextureBindingCount() const { return mExternalTextureBindingExpansionMap.size(); }
diff --git a/src/dawn/native/BindGroupLayoutInternal.h b/src/dawn/native/BindGroupLayoutInternal.h index d41c983..43fddc9 100644 --- a/src/dawn/native/BindGroupLayoutInternal.h +++ b/src/dawn/native/BindGroupLayoutInternal.h
@@ -97,6 +97,7 @@ // Returns |BindingIndex| because dynamic buffers are packed at the front. BindingIndex GetDynamicBufferCount() const; uint32_t GetUnverifiedBufferCount() const; + uint32_t GetStaticSamplerCount() const; // Used to get counts and validate them in pipeline layout creation. Other getters // should be used to get typed integer counts.
diff --git a/src/dawn/native/BindingInfo.cpp b/src/dawn/native/BindingInfo.cpp index 879104e..32b3468 100644 --- a/src/dawn/native/BindingInfo.cpp +++ b/src/dawn/native/BindingInfo.cpp
@@ -41,6 +41,9 @@ [](const TextureBindingLayout&) -> BindingInfoType { return BindingInfoType::Texture; }, [](const StorageTextureBindingLayout&) -> BindingInfoType { return BindingInfoType::StorageTexture; + }, + [](const StaticSamplerHolderBindingLayout&) -> BindingInfoType { + return BindingInfoType::StaticSampler; }); } @@ -89,6 +92,7 @@ } else if (auto* externalTextureBindingLayout = entry.Get<ExternalTextureBindingLayout>()) { perStageBindingCountMember = &PerStageBindingCounts::externalTextureCount; } else if (auto* staticSamplerBindingLayout = entry.Get<StaticSamplerBindingLayout>()) { + ++bindingCounts->staticSamplerCount; perStageBindingCountMember = &PerStageBindingCounts::staticSamplerCount; }
diff --git a/src/dawn/native/BindingInfo.h b/src/dawn/native/BindingInfo.h index 1eba544..638d2f2 100644 --- a/src/dawn/native/BindingInfo.h +++ b/src/dawn/native/BindingInfo.h
@@ -33,6 +33,7 @@ #include <vector> #include "dawn/common/Constants.h" +#include "dawn/common/Ref.h" #include "dawn/common/ityp_array.h" #include "dawn/native/Error.h" #include "dawn/native/Format.h" @@ -56,7 +57,18 @@ // TODO(enga): Figure out a good number for this. static constexpr uint32_t kMaxOptimalBindingsPerGroup = 32; -enum class BindingInfoType { Buffer, Sampler, Texture, StorageTexture, ExternalTexture }; +enum class BindingInfoType { + Buffer, + Sampler, + Texture, + StorageTexture, + ExternalTexture, + StaticSampler +}; + +struct StaticSamplerHolderBindingLayout { + Ref<SamplerBase> sampler; +}; struct BindingInfo { BindingNumber binding; @@ -65,7 +77,8 @@ std::variant<BufferBindingLayout, SamplerBindingLayout, TextureBindingLayout, - StorageTextureBindingLayout> + StorageTextureBindingLayout, + StaticSamplerHolderBindingLayout> bindingLayout; }; @@ -92,6 +105,7 @@ uint32_t unverifiedBufferCount; // Buffers with minimum buffer size unspecified uint32_t dynamicUniformBufferCount; uint32_t dynamicStorageBufferCount; + uint32_t staticSamplerCount; PerStage<PerStageBindingCounts> perStage; };
diff --git a/src/dawn/native/PassResourceUsageTracker.cpp b/src/dawn/native/PassResourceUsageTracker.cpp index 8a9e8e4..96dff82 100644 --- a/src/dawn/native/PassResourceUsageTracker.cpp +++ b/src/dawn/native/PassResourceUsageTracker.cpp
@@ -160,7 +160,8 @@ DAWN_UNREACHABLE(); } }, - [&](const SamplerBindingLayout&) {}); + [&](const SamplerBindingLayout&) {}, // + [&](const StaticSamplerHolderBindingLayout&) {}); } for (const Ref<ExternalTextureBase>& externalTexture : group->GetBoundExternalTextures()) { @@ -225,7 +226,7 @@ mUsage.referencedTextures.insert( group->GetBindingAsTextureView(index)->GetTexture()); }, - [](const SamplerBindingLayout&) {}); + [](const SamplerBindingLayout&) {}, [](const StaticSamplerHolderBindingLayout&) {}); } for (const Ref<ExternalTextureBase>& externalTexture : group->GetBoundExternalTextures()) {
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index 4e24341..fc10ab3 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp
@@ -468,6 +468,15 @@ BindingInfoType bindingLayoutType = GetBindingInfoType(layoutInfo); BindingInfoType shaderBindingType = GetShaderBindingType(shaderInfo); + + if (bindingLayoutType == BindingInfoType::StaticSampler) { + DAWN_INVALID_IF(shaderBindingType != BindingInfoType::Sampler, + "Binding type in the shader (%s) doesn't match the required type of %s for " + "the %s type in the layout.", + shaderBindingType, BindingInfoType::Sampler, bindingLayoutType); + return {}; + } + DAWN_INVALID_IF(bindingLayoutType != shaderBindingType, "Binding type in the shader (%s) doesn't match the type in the layout (%s).", shaderBindingType, bindingLayoutType); @@ -902,6 +911,9 @@ info.bindingInfo.emplace<ExternalTextureBindingInfo>(); break; } + case BindingInfoType::StaticSampler: { + return DAWN_VALIDATION_ERROR("Static samplers not supported in WGSL"); + } default: return DAWN_VALIDATION_ERROR("Unknown binding type in Shader"); }
diff --git a/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp b/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp index df0c29f..46954c4 100644 --- a/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp +++ b/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp
@@ -228,7 +228,13 @@ return {}; }, [](const TextureBindingLayout&) -> MaybeError { return {}; }, - [](const SamplerBindingLayout&) -> MaybeError { return {}; })); + [](const SamplerBindingLayout&) -> MaybeError { return {}; }, + [](const StaticSamplerHolderBindingLayout&) -> MaybeError { + // Static samplers are implemented in the frontend on + // D3D11. + DAWN_UNREACHABLE(); + return {}; + })); } } @@ -366,6 +372,12 @@ } return {}; }, + [&](const StaticSamplerHolderBindingLayout&) -> MaybeError { + // Static samplers are implemented in the frontend on + // D3D11. + DAWN_UNREACHABLE(); + return {}; + }, [&](const SamplerBindingLayout&) -> MaybeError { Sampler* sampler = ToBackend(group->GetBindingAsSampler(bindingIndex)); ID3D11SamplerState* d3d11SamplerState = sampler->GetD3D11SamplerState(); @@ -503,6 +515,11 @@ DAWN_UNREACHABLE(); } }, + [&](const StaticSamplerHolderBindingLayout&) { + // Static samplers are implemented in the frontend on + // D3D11. + DAWN_UNREACHABLE(); + }, [&](const SamplerBindingLayout&) { ID3D11SamplerState* nullSampler = nullptr; if (bindingVisibility & wgpu::ShaderStage::Vertex) {
diff --git a/src/dawn/native/d3d11/PipelineLayoutD3D11.cpp b/src/dawn/native/d3d11/PipelineLayoutD3D11.cpp index b0915be..f3d7400 100644 --- a/src/dawn/native/d3d11/PipelineLayoutD3D11.cpp +++ b/src/dawn/native/d3d11/PipelineLayoutD3D11.cpp
@@ -84,6 +84,11 @@ [&](const SamplerBindingLayout&) { mIndexInfo[group][bindingIndex] = samplerIndex++; }, + [&](const StaticSamplerHolderBindingLayout&) { + // Static samplers are implemented in the frontend on + // D3D11. + DAWN_UNREACHABLE(); + }, [&](const TextureBindingLayout&) { mIndexInfo[group][bindingIndex] = shaderResourceViewIndex++; },
diff --git a/src/dawn/native/d3d12/BindGroupD3D12.cpp b/src/dawn/native/d3d12/BindGroupD3D12.cpp index 7298e6b..2698840 100644 --- a/src/dawn/native/d3d12/BindGroupD3D12.cpp +++ b/src/dawn/native/d3d12/BindGroupD3D12.cpp
@@ -194,6 +194,12 @@ DAWN_UNREACHABLE(); } }, + [](const StaticSamplerHolderBindingLayout&) { + // Static samplers are handled in the frontend. + // TODO(crbug.com/dawn/2483): Implement static samplers in the + // D3D12 backend. + DAWN_UNREACHABLE(); + }, // No-op as samplers will be later initialized by CreateSamplers(). [](const SamplerBindingLayout&) {}); }
diff --git a/src/dawn/native/d3d12/BindGroupLayoutD3D12.cpp b/src/dawn/native/d3d12/BindGroupLayoutD3D12.cpp index ed02fde..7a8d088 100644 --- a/src/dawn/native/d3d12/BindGroupLayoutD3D12.cpp +++ b/src/dawn/native/d3d12/BindGroupLayoutD3D12.cpp
@@ -53,6 +53,13 @@ DAWN_UNREACHABLE(); } }, + [](const StaticSamplerHolderBindingLayout&) -> D3D12_DESCRIPTOR_RANGE_TYPE { + // Static samplers are handled in the frontend. + // TODO(crbug.com/dawn/2483): Implement static samplers in the + // D3D12 backend. + DAWN_UNREACHABLE(); + return D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER; + }, [](const SamplerBindingLayout&) -> D3D12_DESCRIPTOR_RANGE_TYPE { return D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER; }, @@ -120,6 +127,13 @@ // don't need to set DESCRIPTORS_VOLATILE for any binding types. range.Flags = MatchVariant( bindingInfo.bindingLayout, + [](const StaticSamplerHolderBindingLayout&) -> D3D12_DESCRIPTOR_RANGE_FLAGS { + // Static samplers are handled in the frontend. + // TODO(crbug.com/dawn/2483): Implement static samplers in the + // D3D12 backend. + DAWN_UNREACHABLE(); + return D3D12_DESCRIPTOR_RANGE_FLAG_NONE; + }, [](const SamplerBindingLayout&) -> D3D12_DESCRIPTOR_RANGE_FLAGS { // Sampler descriptor ranges don't support DATA_* flags at all since samplers do not // point to data.
diff --git a/src/dawn/native/metal/CommandBufferMTL.mm b/src/dawn/native/metal/CommandBufferMTL.mm index ce703a7..fe3d2c4 100644 --- a/src/dawn/native/metal/CommandBufferMTL.mm +++ b/src/dawn/native/metal/CommandBufferMTL.mm
@@ -624,6 +624,12 @@ atIndex:computeIndex]; } }, + [&](const StaticSamplerHolderBindingLayout&) { + // Static samplers are handled in the frontend. + // TODO(crbug.com/dawn/2482): Implement static samplers in the + // Metal backend. + DAWN_UNREACHABLE(); + }, [&](const TextureBindingLayout&) { auto textureView = ToBackend(group->GetBindingAsTextureView(bindingIndex)); if (hasVertStage) {
diff --git a/src/dawn/native/metal/PipelineLayoutMTL.mm b/src/dawn/native/metal/PipelineLayoutMTL.mm index 6bfc9dd..0eca24f 100644 --- a/src/dawn/native/metal/PipelineLayoutMTL.mm +++ b/src/dawn/native/metal/PipelineLayoutMTL.mm
@@ -78,6 +78,12 @@ [&](const StorageTextureBindingLayout&) { mIndexInfo[stage][group][bindingIndex] = textureIndex; textureIndex++; + }, + [&](const StaticSamplerHolderBindingLayout&) { + // Static samplers are handled in the frontend. + // TODO(crbug.com/dawn/2482): Implement static samplers in the + // Metal backend. + DAWN_UNREACHABLE(); }); } }
diff --git a/src/dawn/native/opengl/CommandBufferGL.cpp b/src/dawn/native/opengl/CommandBufferGL.cpp index a129dd7..321e0e1 100644 --- a/src/dawn/native/opengl/CommandBufferGL.cpp +++ b/src/dawn/native/opengl/CommandBufferGL.cpp
@@ -308,6 +308,11 @@ gl.BindBufferRange(target, index, buffer, offset, binding.size); }, + [&](const StaticSamplerHolderBindingLayout&) { + // Static samplers are implemented in the frontend on + // GL. + DAWN_UNREACHABLE(); + }, [&](const SamplerBindingLayout&) { Sampler* sampler = ToBackend(group->GetBindingAsSampler(bindingIndex)); GLuint samplerIndex = indices[bindingIndex];
diff --git a/src/dawn/native/opengl/PipelineLayoutGL.cpp b/src/dawn/native/opengl/PipelineLayoutGL.cpp index 4b44aa6..b4dc75b 100644 --- a/src/dawn/native/opengl/PipelineLayoutGL.cpp +++ b/src/dawn/native/opengl/PipelineLayoutGL.cpp
@@ -67,6 +67,11 @@ DAWN_UNREACHABLE(); } }, + [&](const StaticSamplerHolderBindingLayout&) { + // Static samplers are implemented in the frontend on + // GL. + DAWN_UNREACHABLE(); + }, [&](const SamplerBindingLayout&) { mIndexInfo[group][bindingIndex] = samplerIndex; samplerIndex++;
diff --git a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp index 70e90bc..7c73c29 100644 --- a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp +++ b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
@@ -85,6 +85,13 @@ } }, [](const SamplerBindingLayout&) { return VK_DESCRIPTOR_TYPE_SAMPLER; }, + [](const StaticSamplerHolderBindingLayout&) { + // Static samplers are implemented in the frontend. + // TODO(crbug.com/dawn/2463): Implement static samplers in the backend + // on Vulkan. + DAWN_UNREACHABLE(); + return VK_DESCRIPTOR_TYPE_SAMPLER; + }, [](const TextureBindingLayout&) { return VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE; }, [](const StorageTextureBindingLayout&) { return VK_DESCRIPTOR_TYPE_STORAGE_IMAGE; }); }
diff --git a/src/dawn/native/vulkan/BindGroupVk.cpp b/src/dawn/native/vulkan/BindGroupVk.cpp index fa40523..e0ee26d 100644 --- a/src/dawn/native/vulkan/BindGroupVk.cpp +++ b/src/dawn/native/vulkan/BindGroupVk.cpp
@@ -104,6 +104,13 @@ write.pImageInfo = &writeImageInfo[numWrites]; return true; }, + [&](const StaticSamplerHolderBindingLayout&) -> bool { + // Static samplers are implemented in the frontend. + // TODO(crbug.com/dawn/2463): Implement static samplers in the backend + // on Vulkan. + DAWN_UNREACHABLE(); + return true; + }, [&](const TextureBindingLayout&) -> bool { TextureView* view = ToBackend(GetBindingAsTextureView(bindingIndex));
diff --git a/src/dawn/native/webgpu_absl_format.cpp b/src/dawn/native/webgpu_absl_format.cpp index e41b2f7..ef1adcf 100644 --- a/src/dawn/native/webgpu_absl_format.cpp +++ b/src/dawn/native/webgpu_absl_format.cpp
@@ -40,6 +40,7 @@ #include "dawn/native/PerStage.h" #include "dawn/native/ProgrammableEncoder.h" #include "dawn/native/RenderPipeline.h" +#include "dawn/native/Sampler.h" #include "dawn/native/ShaderModule.h" #include "dawn/native/Subresource.h" #include "dawn/native/Surface.h" @@ -127,6 +128,10 @@ s->Append(absl::StrFormat(*fmt, static_cast<uint32_t>(value.binding), value.visibility, BindingInfoType::Sampler, layout)); }, + [&](const StaticSamplerHolderBindingLayout& layout) { + s->Append(absl::StrFormat(*fmt, static_cast<uint32_t>(value.binding), value.visibility, + BindingInfoType::StaticSampler, layout)); + }, [&](const TextureBindingLayout& layout) { s->Append(absl::StrFormat(*fmt, static_cast<uint32_t>(value.binding), value.visibility, BindingInfoType::Texture, layout)); @@ -181,6 +186,15 @@ return {true}; } +absl::FormatConvertResult<absl::FormatConversionCharSet::kString> AbslFormatConvert( + const StaticSamplerHolderBindingLayout& value, + const absl::FormatConversionSpec& spec, + absl::FormatSink* s) { + s->Append( + absl::StrFormat("{type: StaticSamplerBindingLayout, sampler: %s}", value.sampler.Get())); + return {true}; +} + // // Objects // @@ -448,6 +462,9 @@ case BindingInfoType::ExternalTexture: s->Append("externalTexture"); break; + case BindingInfoType::StaticSampler: + s->Append("staticSampler"); + break; } return {true}; }
diff --git a/src/dawn/native/webgpu_absl_format.h b/src/dawn/native/webgpu_absl_format.h index dceb3bc..79879bb 100644 --- a/src/dawn/native/webgpu_absl_format.h +++ b/src/dawn/native/webgpu_absl_format.h
@@ -100,6 +100,12 @@ const absl::FormatConversionSpec& spec, absl::FormatSink* s); +struct StaticSamplerHolderBindingLayout; +absl::FormatConvertResult<absl::FormatConversionCharSet::kString> AbslFormatConvert( + const StaticSamplerHolderBindingLayout& value, + const absl::FormatConversionSpec& spec, + absl::FormatSink* s); + // // Objects //
diff --git a/src/dawn/tests/unittests/validation/BindGroupValidationTests.cpp b/src/dawn/tests/unittests/validation/BindGroupValidationTests.cpp index c856614..999b4b0 100644 --- a/src/dawn/tests/unittests/validation/BindGroupValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/BindGroupValidationTests.cpp
@@ -1818,10 +1818,182 @@ ASSERT_DEVICE_ERROR(device.CreateBindGroupLayout(&desc)); } +// Verifies that creation of a bind group with no entry for a static sampler +// succeeds. +TEST_F(BindGroupLayoutWithStaticSamplersValidationTest, CreateBindGroupWithStaticSamplerSupported) { + wgpu::BindGroupLayoutEntry binding = {}; + binding.binding = 0; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + staticSamplerBinding.sampler = device.CreateSampler(); + binding.nextInChain = &staticSamplerBinding; + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 1; + desc.entries = &binding; + + wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&desc); + + wgpu::BindGroupDescriptor descriptor; + descriptor.layout = layout; + descriptor.entryCount = 0; + + device.CreateBindGroup(&descriptor); +} + +// Verifies that creation of a correctly-specified bind group for a layout that +// has a static sampler and a sampler succeeds. +TEST_F(BindGroupLayoutWithStaticSamplersValidationTest, + CreateBindGroupWithStaticSamplerAndSamplerSupported) { + std::vector<wgpu::BindGroupLayoutEntry> entries; + + wgpu::BindGroupLayoutEntry binding0 = {}; + binding0.binding = 0; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + staticSamplerBinding.sampler = device.CreateSampler(); + binding0.nextInChain = &staticSamplerBinding; + entries.push_back(binding0); + + wgpu::BindGroupLayoutEntry binding1 = {}; + binding1.binding = 1; + binding1.sampler.type = wgpu::SamplerBindingType::Filtering; + entries.push_back(binding1); + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 2; + desc.entries = entries.data(); + + wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&desc); + + wgpu::SamplerDescriptor samplerDesc; + samplerDesc.minFilter = wgpu::FilterMode::Linear; + utils::MakeBindGroup(device, layout, {{1, device.CreateSampler(&samplerDesc)}}); +} + +// Verifies that creation of a correctly-specified bind group for a layout that +// has a sampler and a static sampler succeeds. +TEST_F(BindGroupLayoutWithStaticSamplersValidationTest, + CreateBindGroupWithSamplerAndStaticSamplerSupported) { + std::vector<wgpu::BindGroupLayoutEntry> entries; + + wgpu::BindGroupLayoutEntry binding0 = {}; + binding0.binding = 0; + binding0.sampler.type = wgpu::SamplerBindingType::Filtering; + entries.push_back(binding0); + + wgpu::BindGroupLayoutEntry binding1 = {}; + binding1.binding = 1; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + staticSamplerBinding.sampler = device.CreateSampler(); + binding1.nextInChain = &staticSamplerBinding; + entries.push_back(binding1); + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 2; + desc.entries = entries.data(); + + wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&desc); + + wgpu::SamplerDescriptor samplerDesc; + samplerDesc.minFilter = wgpu::FilterMode::Linear; + utils::MakeBindGroup(device, layout, {{0, device.CreateSampler(&samplerDesc)}}); +} + +// Verifies that creating a bind group with an entry for a static sampler causes +// an error. +TEST_F(BindGroupLayoutWithStaticSamplersValidationTest, + EntryForStaticSamplerInBindGroupCausesError) { + wgpu::BindGroupLayoutEntry binding = {}; + binding.binding = 0; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + staticSamplerBinding.sampler = device.CreateSampler(); + binding.nextInChain = &staticSamplerBinding; + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 1; + desc.entries = &binding; + + wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&desc); + + wgpu::SamplerDescriptor samplerDesc; + samplerDesc.minFilter = wgpu::FilterMode::Linear; + ASSERT_DEVICE_ERROR( + utils::MakeBindGroup(device, layout, {{0, device.CreateSampler(&samplerDesc)}})); +} + +// Verifies that creation of a bind group with the correct number of entries for a layout that has a +// sampler and a static sampler raises an error if the entry is specified at the +// index of the static sampler rather than that of the sampler. +TEST_F(BindGroupLayoutWithStaticSamplersValidationTest, + CorrectNumberOfEntriesButEntryForStaticSamplerAtSecondIndexInBindGroupCausesError) { + std::vector<wgpu::BindGroupLayoutEntry> entries; + + wgpu::BindGroupLayoutEntry binding0 = {}; + binding0.binding = 0; + binding0.sampler.type = wgpu::SamplerBindingType::Filtering; + entries.push_back(binding0); + + wgpu::BindGroupLayoutEntry binding1 = {}; + binding1.binding = 1; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + staticSamplerBinding.sampler = device.CreateSampler(); + binding1.nextInChain = &staticSamplerBinding; + entries.push_back(binding1); + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 2; + desc.entries = entries.data(); + + wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&desc); + + wgpu::SamplerDescriptor samplerDesc; + samplerDesc.minFilter = wgpu::FilterMode::Linear; + ASSERT_DEVICE_ERROR( + utils::MakeBindGroup(device, layout, {{1, device.CreateSampler(&samplerDesc)}})); +} + +// Verifies that creation of a bind group with the correct number of entries for a layout that has a +// static sampler and a sampler raises an error if the entry is specified at the +// index of the static sampler rather than that of the sampler. +TEST_F(BindGroupLayoutWithStaticSamplersValidationTest, + CorrectNumberOfEntriesButEntryForStaticSamplerInBindGroupCausesError) { + std::vector<wgpu::BindGroupLayoutEntry> entries; + + wgpu::BindGroupLayoutEntry binding0 = {}; + binding0.binding = 0; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + staticSamplerBinding.sampler = device.CreateSampler(); + binding0.nextInChain = &staticSamplerBinding; + entries.push_back(binding0); + + wgpu::BindGroupLayoutEntry binding1 = {}; + binding1.binding = 1; + binding1.sampler.type = wgpu::SamplerBindingType::Filtering; + entries.push_back(binding1); + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 2; + desc.entries = entries.data(); + + wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&desc); + + wgpu::SamplerDescriptor samplerDesc; + samplerDesc.minFilter = wgpu::FilterMode::Linear; + ASSERT_DEVICE_ERROR( + utils::MakeBindGroup(device, layout, {{0, device.CreateSampler(&samplerDesc)}})); +} + constexpr uint32_t kBindingSize = 8; class SetBindGroupValidationTest : public ValidationTest { public: + WGPUDevice CreateTestDevice(native::Adapter dawnAdapter, + wgpu::DeviceDescriptor descriptor) override { + wgpu::FeatureName requiredFeatures[1] = {wgpu::FeatureName::StaticSamplers}; + descriptor.requiredFeatures = requiredFeatures; + descriptor.requiredFeatureCount = 1; + return dawnAdapter.CreateDevice(&descriptor); + } + void SetUp() override { ValidationTest::SetUp(); @@ -2457,6 +2629,63 @@ } } +// Test that a static sampler is valid to access from a shader module. +TEST_F(SetBindGroupValidationTest, StaticSamplerAccessedFromShader) { + wgpu::BindGroupLayoutEntry binding = {}; + binding.binding = 0; + binding.visibility = wgpu::ShaderStage::Compute; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + staticSamplerBinding.sampler = device.CreateSampler(); + binding.nextInChain = &staticSamplerBinding; + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 1; + desc.entries = &binding; + + wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&desc); + + wgpu::PipelineLayout pl = utils::MakePipelineLayout(device, {layout}); + + wgpu::ComputePipelineDescriptor pipelineDesc; + pipelineDesc.layout = pl; + pipelineDesc.compute.module = utils::CreateShaderModule(device, R"( + @group(0) @binding(0) var s : sampler; + @compute @workgroup_size(1) fn main() { + _ = s; + } + )"); + device.CreateComputePipeline(&pipelineDesc); +} + +// Test that a static sampler cannot be accessed from a shader module as a +// non-sampler variable type. +TEST_F(SetBindGroupValidationTest, StaticSamplerAccessedFromShaderAsNonSamplerTypeCausesError) { + wgpu::BindGroupLayoutEntry binding = {}; + binding.binding = 0; + binding.visibility = wgpu::ShaderStage::Compute; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + staticSamplerBinding.sampler = device.CreateSampler(); + binding.nextInChain = &staticSamplerBinding; + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 1; + desc.entries = &binding; + + wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&desc); + + wgpu::PipelineLayout pl = utils::MakePipelineLayout(device, {layout}); + + wgpu::ComputePipelineDescriptor pipelineDesc; + pipelineDesc.layout = pl; + pipelineDesc.compute.module = utils::CreateShaderModule(device, R"( + @group(0) @binding(0) var t : texture_2d<f32>; + @compute @workgroup_size(1) fn main() { + _ = textureDimensions(t); + } + )"); + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&pipelineDesc)); +} + class SetBindGroupPersistenceValidationTest : public ValidationTest { protected: void SetUp() override {
diff --git a/src/dawn/tests/unittests/validation/ObjectCachingTests.cpp b/src/dawn/tests/unittests/validation/ObjectCachingTests.cpp index f412ab0..80280b7 100644 --- a/src/dawn/tests/unittests/validation/ObjectCachingTests.cpp +++ b/src/dawn/tests/unittests/validation/ObjectCachingTests.cpp
@@ -40,6 +40,14 @@ // These tests works assuming Dawn Native's object deduplication. Comparing the pointer is // exploiting an implementation detail of Dawn Native. class ObjectCachingTest : public ValidationTest { + WGPUDevice CreateTestDevice(native::Adapter dawnAdapter, + wgpu::DeviceDescriptor descriptor) override { + wgpu::FeatureName requiredFeatures[1] = {wgpu::FeatureName::StaticSamplers}; + descriptor.requiredFeatures = requiredFeatures; + descriptor.requiredFeatureCount = 1; + return dawnAdapter.CreateDevice(&descriptor); + } + void SetUp() override { ValidationTest::SetUp(); DAWN_SKIP_TEST_IF(UsesWire()); @@ -114,6 +122,77 @@ EXPECT_THAT(bgl, BindGroupLayoutEq(sameBgl)); } +// Test that BindGroupLayouts with a static sampler entry are correctly +// deduplicated. +TEST_F(ObjectCachingTest, BindGroupLayoutStaticSamplerDeduplication) { + // TODO(crbug.com/dawn/2489): The inequality check between bind group + // layouts with distinct static samplers fails on the MSVC bots. + DAWN_SKIP_TEST_IF(DAWN_PLATFORM_IS(WINDOWS)); + + wgpu::BindGroupLayoutEntry binding = {}; + binding.binding = 0; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + wgpu::SamplerDescriptor samplerDesc; + samplerDesc.addressModeU = wgpu::AddressMode::ClampToEdge; + staticSamplerBinding.sampler = device.CreateSampler(&samplerDesc); + binding.nextInChain = &staticSamplerBinding; + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 1; + desc.entries = &binding; + + wgpu::SamplerDescriptor otherSamplerDesc; + otherSamplerDesc.addressModeU = wgpu::AddressMode::Repeat; + wgpu::Sampler otherSampler = device.CreateSampler(&otherSamplerDesc); + + EXPECT_NE(staticSamplerBinding.sampler.Get(), otherSampler.Get()); + + wgpu::BindGroupLayoutEntry otherBinding = {}; + otherBinding.binding = 0; + wgpu::StaticSamplerBindingLayout otherStaticSamplerBinding = {}; + otherStaticSamplerBinding.sampler = otherSampler; + otherBinding.nextInChain = &otherStaticSamplerBinding; + + EXPECT_NE(staticSamplerBinding.sampler.Get(), otherStaticSamplerBinding.sampler.Get()); + + wgpu::BindGroupLayoutDescriptor otherDesc = {}; + otherDesc.entryCount = 1; + otherDesc.entries = &otherBinding; + + wgpu::BindGroupLayout bgl = device.CreateBindGroupLayout(&desc); + wgpu::BindGroupLayout sameBgl = device.CreateBindGroupLayout(&desc); + wgpu::BindGroupLayout otherStaticSamplerBgl = device.CreateBindGroupLayout(&otherDesc); + wgpu::BindGroupLayout otherBgl = utils::MakeBindGroupLayout( + device, {{0, wgpu::ShaderStage::Fragment, wgpu::SamplerBindingType::Filtering}}); + + EXPECT_THAT(bgl, BindGroupLayoutEq(sameBgl)); + EXPECT_THAT(bgl, Not(BindGroupLayoutEq(otherStaticSamplerBgl))); + EXPECT_THAT(bgl, Not(BindGroupLayoutEq(otherBgl))); +} + +// Test that BindGroupLayouts with a static sampler entry keep a reference +// to the static sampler, such that if a sampler is created from the same +// params the same object will be returned. +TEST_F(ObjectCachingTest, BindGroupLayoutKeepsRefToStaticSampler) { + auto sampler1 = device.CreateSampler(); + wgpu::BindGroupLayoutEntry binding = {}; + binding.binding = 0; + wgpu::StaticSamplerBindingLayout staticSamplerBinding = {}; + staticSamplerBinding.sampler = sampler1; + binding.nextInChain = &staticSamplerBinding; + + wgpu::BindGroupLayoutDescriptor desc = {}; + desc.entryCount = 1; + desc.entries = &binding; + + wgpu::BindGroupLayout bgl = device.CreateBindGroupLayout(&desc); + + auto samplerRawPtr = sampler1.Get(); + sampler1 = nullptr; + auto sampler2 = device.CreateSampler(); + EXPECT_EQ(samplerRawPtr, sampler2.Get()); +} + // Test that PipelineLayouts are correctly deduplicated. TEST_F(ObjectCachingTest, PipelineLayoutDeduplication) { wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout(