Add static sampler BindGroup validation

Add validation for the following:
1. YCbCr static samplers always sample from YCbCr textures.
2. YCbCr textures are never sampled by non-YCbCr static samplers.
3. YCbCr textures are sampled by a YCbCr static sampler. They cannot be
   dynamically sampled.

This requires changes to expose YCbCrVkDescriptor on TextureView. Also
includes tests for the validation.

Bug: 42241425
Change-Id: I0d9b2f0b80c7e66b0e8e6f21cdc20f5c6ee468db
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/216115
Commit-Queue: Kyle Charbonneau <kylechar@google.com>
Reviewed-by: Loko Kung <lokokung@google.com>
diff --git a/src/dawn/native/BindGroup.cpp b/src/dawn/native/BindGroup.cpp
index f8f53f4..147a4b5 100644
--- a/src/dawn/native/BindGroup.cpp
+++ b/src/dawn/native/BindGroup.cpp
@@ -27,12 +27,17 @@
 
 #include "dawn/native/BindGroup.h"
 
+#include <variant>
+
+#include "absl/container/flat_hash_map.h"
 #include "dawn/common/Assert.h"
 #include "dawn/common/MatchVariant.h"
 #include "dawn/common/Math.h"
 #include "dawn/common/ityp_bitset.h"
 #include "dawn/native/Adapter.h"
 #include "dawn/native/BindGroupLayout.h"
+#include "dawn/native/BindGroupLayoutInternal.h"
+#include "dawn/native/BindingInfo.h"
 #include "dawn/native/Buffer.h"
 #include "dawn/native/ChainUtils.h"
 #include "dawn/native/CommandValidation.h"
@@ -345,6 +350,58 @@
     }
 }
 
+MaybeError ValidateStaticSamplersWithSampledTextures(const BindGroupDescriptor* descriptor,
+                                                     const BindGroupLayoutInternalBase* layout) {
+    absl::flat_hash_map<BindingNumber, uint32_t> bindingNumberToEntryIndexMap;
+    for (uint32_t i = 0; i < descriptor->entryCount; ++i) {
+        bindingNumberToEntryIndexMap[BindingNumber(descriptor->entries[i].binding)] = i;
+    }
+
+    // Entry indices of YCbCr textures sampled by a static sampler.
+    ityp::bitset<uint32_t, kMaxBindingsPerPipelineLayout> sampledYcbcrTextures;
+    for (BindingIndex index{0}; index < layout->GetBindingCount(); ++index) {
+        const BindingInfo& bindingInfo = layout->GetBindingInfo(index);
+        auto* staticSamplerLayout =
+            std::get_if<StaticSamplerBindingInfo>(&bindingInfo.bindingLayout);
+        if (staticSamplerLayout && staticSamplerLayout->isUsedForSingleTextureBinding) {
+            const SamplerBase* sampler = staticSamplerLayout->sampler.Get();
+
+            uint32_t textureEntryIndex = bindingNumberToEntryIndexMap.at(
+                BindingNumber(staticSamplerLayout->sampledTextureBinding));
+            const TextureViewBase* textureView = descriptor->entries[textureEntryIndex].textureView;
+
+            // Compare static sampler and sampled textures to make sure they are compatible.
+            if (sampler->IsYCbCr()) {
+                DAWN_INVALID_IF(!textureView->IsYCbCr(),
+                                "YCbCr static sampler at binding (%u) samples a non-YCbCr texture.",
+                                bindingInfo.binding);
+
+                sampledYcbcrTextures.set(textureEntryIndex);
+            } else {
+                DAWN_INVALID_IF(textureView->IsYCbCr(),
+                                "Non-YCbCr static sampler at binding (%u) samples a YCbCr texture.",
+                                bindingInfo.binding);
+            }
+        }
+    }
+
+    // Validate that all YCbCr texture entries are sampled by a static sampler.
+    const auto& bindingMap = layout->GetBindingMap();
+    for (uint32_t i = 0; i < descriptor->entryCount; ++i) {
+        const BindGroupEntry& entry = descriptor->entries[i];
+        const BindingInfo& bindingInfo =
+            layout->GetBindingInfo(bindingMap.at(BindingNumber(entry.binding)));
+        if (std::holds_alternative<TextureBindingInfo>(bindingInfo.bindingLayout) &&
+            entry.textureView && entry.textureView->IsYCbCr()) {
+            DAWN_INVALID_IF(!sampledYcbcrTextures.test(i),
+                            "YCbCr texture at binding (%u) is not sampled by a static sampler.",
+                            entry.binding);
+        }
+    }
+
+    return {};
+}
+
 }  // anonymous namespace
 
 MaybeError ValidateBindGroupDescriptor(DeviceBase* device,
@@ -371,6 +428,8 @@
     const BindGroupLayoutInternalBase::BindingMap& bindingMap = layout->GetBindingMap();
     DAWN_ASSERT(bindingMap.size() <= kMaxBindingsPerPipelineLayout);
 
+    bool needsCrossBindingValidation = layout->NeedsCrossBindingValidation();
+
     ityp::bitset<BindingIndex, kMaxBindingsPerPipelineLayout> bindingsSet;
     for (uint32_t i = 0; i < descriptor->entryCount; ++i) {
         const BindGroupEntry& entry = descriptor->entries[i];
@@ -430,6 +489,11 @@
                                  "validating entries[%u] as a Sampled Texture."
                                  "\nExpected entry layout: %s",
                                  i, layout);
+                if (entry.textureView->IsYCbCr()) {
+                    // Need to validate that the YCbCr texture is statically sampled.
+                    needsCrossBindingValidation = true;
+                }
+
                 return {};
             },
             [&](const StorageTextureBindingInfo& layout) -> MaybeError {
@@ -465,6 +529,12 @@
     // We don't validate the equality because it wouldn't be possible to cover it with a test.
     DAWN_ASSERT(bindingsSet.count() == expectedBindingsCount);
 
+    if (needsCrossBindingValidation) {
+        // This additional validation is only needed when there are static samplers used with a
+        // single texture binding and/or there are YCbCr textures.
+        DAWN_TRY(ValidateStaticSamplersWithSampledTextures(descriptor, layout));
+    }
+
     return {};
 }
 
diff --git a/src/dawn/native/BindGroupLayoutInternal.cpp b/src/dawn/native/BindGroupLayoutInternal.cpp
index 9a4971b..bb4c433 100644
--- a/src/dawn/native/BindGroupLayoutInternal.cpp
+++ b/src/dawn/native/BindGroupLayoutInternal.cpp
@@ -300,7 +300,8 @@
 UnpackedExpandedBglEntries ExtractAndExpandBglEntries(
     const BindGroupLayoutDescriptor* descriptor,
     BindingCounts* bindingCounts,
-    ExternalTextureBindingExpansionMap* externalTextureBindingExpansions) {
+    ExternalTextureBindingExpansionMap* externalTextureBindingExpansions,
+    bool* needsCrossBindingValidation) {
     UnpackedExpandedBglEntries result;
     std::list<BindGroupLayoutEntry>& additionalEntries = result.additionalEntries;
     std::vector<UnpackedPtr<BindGroupLayoutEntry>>& expandedOutput = result.unpackedEntries;
@@ -352,6 +353,12 @@
             externalTextureBindingExpansions->insert(
                 {BindingNumber(entry->binding), bindingExpansion});
         } else {
+            if (auto* staticSamplerBindingLayout = entry.Get<StaticSamplerBindingLayout>()) {
+                if (staticSamplerBindingLayout->sampledTextureBinding != WGPU_LIMIT_U32_UNDEFINED) {
+                    *needsCrossBindingValidation = true;
+                }
+            }
+
             expandedOutput.push_back(entry);
         }
     }
@@ -627,7 +634,8 @@
     ApiObjectBase::UntrackedByDeviceTag tag)
     : ApiObjectBase(device, descriptor->label), mUnexpandedBindingCount(descriptor->entryCount) {
     auto unpackedBindings = ExtractAndExpandBglEntries(descriptor, &mBindingCounts,
-                                                       &mExternalTextureBindingExpansionMap);
+                                                       &mExternalTextureBindingExpansionMap,
+                                                       &mNeedsCrossBindingValidation);
     auto& sortedBindings = unpackedBindings.unpackedEntries;
 
     std::sort(sortedBindings.begin(), sortedBindings.end(), SortBindingsCompare);
@@ -778,6 +786,10 @@
     return mExternalTextureBindingExpansionMap;
 }
 
+bool BindGroupLayoutInternalBase::NeedsCrossBindingValidation() const {
+    return mNeedsCrossBindingValidation;
+}
+
 uint32_t BindGroupLayoutInternalBase::GetUnexpandedBindingCount() const {
     return mUnexpandedBindingCount;
 }
diff --git a/src/dawn/native/BindGroupLayoutInternal.h b/src/dawn/native/BindGroupLayoutInternal.h
index 2d216ac..9c8e6ee 100644
--- a/src/dawn/native/BindGroupLayoutInternal.h
+++ b/src/dawn/native/BindGroupLayoutInternal.h
@@ -113,6 +113,8 @@
 
     uint32_t GetUnexpandedBindingCount() const;
 
+    bool NeedsCrossBindingValidation() const;
+
     // Tests that the BindingInfo of two bind groups are equal.
     bool IsLayoutEqual(const BindGroupLayoutInternalBase* other) const;
 
@@ -158,6 +160,7 @@
     BindGroupLayoutInternalBase(DeviceBase* device, ObjectBase::ErrorTag tag, StringView label);
 
     BindingCounts mBindingCounts = {};
+    bool mNeedsCrossBindingValidation = false;
     ityp::vector<BindingIndex, BindingInfo> mBindingInfo;
 
     // Map from BindGroupLayoutEntry.binding to packed indices.
diff --git a/src/dawn/native/Sampler.cpp b/src/dawn/native/Sampler.cpp
index 5999265..46b2a39 100644
--- a/src/dawn/native/Sampler.cpp
+++ b/src/dawn/native/Sampler.cpp
@@ -138,6 +138,11 @@
     return mIsYCbCr;
 }
 
+YCbCrVkDescriptor SamplerBase::GetYCbCrVkDescriptor() const {
+    DAWN_ASSERT(IsYCbCr());
+    return mYCbCrVkDescriptor;
+}
+
 size_t SamplerBase::ComputeContentHash() {
     ObjectContentHasher recorder;
     // NOTE: We always hash the state of `mYCbCrVkDescriptor` to avoid splitting
diff --git a/src/dawn/native/Sampler.h b/src/dawn/native/Sampler.h
index e7179f2..4f981e1 100644
--- a/src/dawn/native/Sampler.h
+++ b/src/dawn/native/Sampler.h
@@ -59,6 +59,8 @@
     bool IsComparison() const;
     bool IsFiltering() const;
     bool IsYCbCr() const;
+    // Valid to call only if `IsYCbCr()` is true.
+    YCbCrVkDescriptor GetYCbCrVkDescriptor() const;
 
     // Functions necessary for the unordered_set<SamplerBase*>-based cache.
     size_t ComputeContentHash() override;
@@ -72,12 +74,6 @@
   protected:
     void DestroyImpl() override;
 
-    // Valid to call only if `IsYCbCr()` is true.
-    YCbCrVkDescriptor GetYCbCrVkDescriptor() {
-        DAWN_ASSERT(IsYCbCr());
-        return mYCbCrVkDescriptor;
-    }
-
   private:
     SamplerBase(DeviceBase* device, ObjectBase::ErrorTag tag, StringView label);
 
diff --git a/src/dawn/native/Texture.cpp b/src/dawn/native/Texture.cpp
index 06a8e17..3c8cff1 100644
--- a/src/dawn/native/Texture.cpp
+++ b/src/dawn/native/Texture.cpp
@@ -1526,6 +1526,15 @@
     return mInternalUsage;
 }
 
+bool TextureViewBase::IsYCbCr() const {
+    return false;
+}
+
+YCbCrVkDescriptor TextureViewBase::GetYCbCrVkDescriptor() const {
+    DAWN_UNREACHABLE();
+    return {};
+}
+
 ApiObjectList* TextureViewBase::GetObjectTrackingList() {
     if (mTexture != nullptr) {
         return mTexture->GetViewTrackingList();
diff --git a/src/dawn/native/Texture.h b/src/dawn/native/Texture.h
index cc73f4e..19b1d7e 100644
--- a/src/dawn/native/Texture.h
+++ b/src/dawn/native/Texture.h
@@ -254,6 +254,10 @@
     wgpu::TextureUsage GetUsage() const;
     wgpu::TextureUsage GetInternalUsage() const;
 
+    virtual bool IsYCbCr() const;
+    // Valid to call only if `IsYCbCr()` is true.
+    virtual YCbCrVkDescriptor GetYCbCrVkDescriptor() const;
+
   protected:
     void DestroyImpl() override;
 
diff --git a/src/dawn/native/vulkan/TextureVk.cpp b/src/dawn/native/vulkan/TextureVk.cpp
index c31379f..498e57b 100644
--- a/src/dawn/native/vulkan/TextureVk.cpp
+++ b/src/dawn/native/vulkan/TextureVk.cpp
@@ -1873,6 +1873,7 @@
 
     VkSamplerYcbcrConversionInfo samplerYCbCrInfo = {};
     if (auto* yCbCrVkDescriptor = descriptor.Get<YCbCrVkDescriptor>()) {
+        mIsYCbCr = true;
         mYCbCrVkDescriptor = *yCbCrVkDescriptor;
         mYCbCrVkDescriptor.nextInChain = nullptr;
 
@@ -2005,6 +2006,15 @@
     return view;
 }
 
+bool TextureView::IsYCbCr() const {
+    return mIsYCbCr;
+}
+
+YCbCrVkDescriptor TextureView::GetYCbCrVkDescriptor() const {
+    DAWN_ASSERT(IsYCbCr());
+    return mYCbCrVkDescriptor;
+}
+
 void TextureView::SetLabelImpl() {
     SetDebugName(ToBackend(GetDevice()), mHandle, "Dawn_TextureView", GetLabel());
 }
diff --git a/src/dawn/native/vulkan/TextureVk.h b/src/dawn/native/vulkan/TextureVk.h
index 3c7cab8..33ae96f 100644
--- a/src/dawn/native/vulkan/TextureVk.h
+++ b/src/dawn/native/vulkan/TextureVk.h
@@ -308,6 +308,9 @@
 
     ResultOrError<VkImageView> GetOrCreate2DViewOn3D(uint32_t depthSlice = 0u);
 
+    bool IsYCbCr() const override;
+    YCbCrVkDescriptor GetYCbCrVkDescriptor() const override;
+
   private:
     ~TextureView() override;
     void DestroyImpl() override;
@@ -324,6 +327,7 @@
     VkImageView mHandle = VK_NULL_HANDLE;
     VkImageView mHandleForBGRA8UnormStorage = VK_NULL_HANDLE;
     VkSamplerYcbcrConversion mSamplerYCbCrConversion = VK_NULL_HANDLE;
+    bool mIsYCbCr = false;
     YCbCrVkDescriptor mYCbCrVkDescriptor;
     std::vector<VkImageView> mHandlesFor2DViewOn3D;
 };
diff --git a/src/dawn/tests/end2end/YCbCrInfoTests.cpp b/src/dawn/tests/end2end/YCbCrInfoTests.cpp
index 43aa101..f71143a 100644
--- a/src/dawn/tests/end2end/YCbCrInfoTests.cpp
+++ b/src/dawn/tests/end2end/YCbCrInfoTests.cpp
@@ -86,9 +86,11 @@
 #endif
 }
 
-wgpu::TextureViewDescriptor CreateDefaultViewDescriptor(wgpu::TextureViewDimension dimension) {
+wgpu::TextureViewDescriptor CreateDefaultViewDescriptor(
+    wgpu::TextureViewDimension dimension,
+    wgpu::TextureFormat format = kDefaultTextureFormat) {
     wgpu::TextureViewDescriptor descriptor;
-    descriptor.format = kDefaultTextureFormat;
+    descriptor.format = format;
     descriptor.dimension = dimension;
     descriptor.baseMipLevel = 0;
     if (dimension != wgpu::TextureViewDimension::e1D) {
@@ -311,7 +313,10 @@
     wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&layoutDesc);
 
     wgpu::Texture texture = Create2DTexture(device);
-    wgpu::TextureView textureView = Create2DTextureView(texture, &yCbCrDesc);
+
+    wgpu::YCbCrVkDescriptor yCbCrDescTex = {};
+    yCbCrDescTex.vkFormat = VK_FORMAT_R8G8B8A8_UNORM;
+    wgpu::TextureView textureView = Create2DTextureView(texture, &yCbCrDescTex);
 
     utils::MakeBindGroup(device, layout, {{1, textureView}});
 }
@@ -504,12 +509,110 @@
         device, layout, {{1, device.CreateSampler(&samplerDesc0)}, {2, textureView}}));
 }
 
+// Tests that creating a bind group fails when YCbCr texture isn't sampled by a static sampler.
+TEST_P(YCbCrInfoTest, CreateBindGroupWithoutYCbCrSampler) {
+    std::vector<wgpu::BindGroupLayoutEntry> entries;
+
+    wgpu::YCbCrVkDescriptor yCbCrDesc = {};
+    yCbCrDesc.vkFormat = VK_FORMAT_R8G8B8A8_UNORM;
+
+    wgpu::BindGroupLayoutEntry& binding0 = entries.emplace_back();
+    binding0.binding = 0;
+    binding0.texture.sampleType = wgpu::TextureSampleType::Float;
+    binding0.texture.viewDimension = wgpu::TextureViewDimension::e2D;
+    binding0.texture.multisampled = false;
+
+    wgpu::BindGroupLayoutDescriptor layoutDesc = {};
+    layoutDesc.entryCount = entries.size();
+    layoutDesc.entries = entries.data();
+
+    wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&layoutDesc);
+
+    wgpu::Texture texture = Create2DTexture(device);
+    wgpu::TextureView textureView = Create2DTextureView(texture, &yCbCrDesc);
+
+    ASSERT_DEVICE_ERROR(utils::MakeBindGroup(device, layout, {{0, textureView}}));
+}
+
+// Tests that creating a bind group fails when a YCbCr static sampler samples a non-YCbCr texture.
+TEST_P(YCbCrInfoTest, CreatBindGroupYCbCrStaticSamplerWrongTexture) {
+    std::vector<wgpu::BindGroupLayoutEntry> entries;
+    wgpu::YCbCrVkDescriptor yCbCrDesc = {};
+    yCbCrDesc.vkFormat = VK_FORMAT_R8G8B8A8_UNORM;
+
+    wgpu::BindGroupLayoutEntry& binding0 = entries.emplace_back();
+    binding0.binding = 0;
+    wgpu::StaticSamplerBindingLayout staticSamplerBinding = {};
+    staticSamplerBinding.sampler = CreateYCbCrSampler(device, &yCbCrDesc);
+    staticSamplerBinding.sampledTextureBinding = 1;
+    binding0.nextInChain = &staticSamplerBinding;
+
+    wgpu::BindGroupLayoutEntry& binding1 = entries.emplace_back();
+    binding1.binding = 1;
+    binding1.texture.sampleType = wgpu::TextureSampleType::Float;
+    binding1.texture.viewDimension = wgpu::TextureViewDimension::e2D;
+    binding1.texture.multisampled = false;
+
+    wgpu::BindGroupLayoutDescriptor layoutDesc = {};
+    layoutDesc.entryCount = entries.size();
+    layoutDesc.entries = entries.data();
+
+    wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&layoutDesc);
+
+    wgpu::TextureDescriptor descriptor;
+    descriptor.dimension = wgpu::TextureDimension::e2D;
+    descriptor.size.width = 32;
+    descriptor.size.height = 32;
+    descriptor.size.depthOrArrayLayers = kDefaultLayerCount;
+    descriptor.sampleCount = 1u;
+    descriptor.format = wgpu::TextureFormat::RGBA8Snorm;
+    descriptor.mipLevelCount = kDefaultMipLevels;
+    descriptor.usage = wgpu::TextureUsage::TextureBinding;
+    wgpu::Texture texture = device.CreateTexture(&descriptor);
+
+    wgpu::TextureViewDescriptor textureDesc = CreateDefaultViewDescriptor(
+        wgpu::TextureViewDimension::e2D, wgpu::TextureFormat::RGBA8Snorm);
+    textureDesc.arrayLayerCount = 1;
+    wgpu::TextureView textureView = texture.CreateView(&textureDesc);
+
+    ASSERT_DEVICE_ERROR(utils::MakeBindGroup(device, layout, {{1, textureView}}));
+}
+
+// Tests that creating a bind group fails when a non-YCbCr static sampler samples a YCbCr texture.
+TEST_P(YCbCrInfoTest, CreatBindGroupYCbCrTextureWrongStaticSampler) {
+    std::vector<wgpu::BindGroupLayoutEntry> entries;
+
+    wgpu::BindGroupLayoutEntry& binding0 = entries.emplace_back();
+    binding0.binding = 0;
+    wgpu::StaticSamplerBindingLayout staticSamplerBinding = {};
+    wgpu::SamplerDescriptor samplerDesc;
+    staticSamplerBinding.sampler = device.CreateSampler(&samplerDesc);
+    staticSamplerBinding.sampledTextureBinding = 1;
+    binding0.nextInChain = &staticSamplerBinding;
+
+    wgpu::BindGroupLayoutEntry& binding1 = entries.emplace_back();
+    binding1.binding = 1;
+    binding1.texture.sampleType = wgpu::TextureSampleType::Float;
+    binding1.texture.viewDimension = wgpu::TextureViewDimension::e2D;
+    binding1.texture.multisampled = false;
+
+    wgpu::BindGroupLayoutDescriptor layoutDesc = {};
+    layoutDesc.entryCount = entries.size();
+    layoutDesc.entries = entries.data();
+
+    wgpu::BindGroupLayout layout = device.CreateBindGroupLayout(&layoutDesc);
+
+    wgpu::YCbCrVkDescriptor yCbCrDesc = {};
+    yCbCrDesc.vkFormat = VK_FORMAT_R8G8B8A8_UNORM;
+
+    wgpu::Texture texture = Create2DTexture(device);
+    wgpu::TextureView textureView = Create2DTextureView(texture, &yCbCrDesc);
+
+    ASSERT_DEVICE_ERROR(utils::MakeBindGroup(device, layout, {{1, textureView}}));
+}
+
 DAWN_INSTANTIATE_TEST(YCbCrInfoTest, VulkanBackend());
 
-// TODO(crbug.com/dawn/2476): Add test validating binding fails if texture view ycbcr info is
-// different from that on sampler
-// TODO(crbug.com/dawn/2476): Add test validating binding passes if texture view ycbcr info is same
-// as that on sampler
 // TODO(crbug.com/dawn/2476): Add validation that mipLevel, arrayLayers are always 1 along with 2D
 // view dimension (see
 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkImageCreateInfo.html) with