Static samplers are passed to root signature creation

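Static samplers are provided to the BindGroupLayout via BindGroupLayoutEntry, and the D3D12 BindGroupLayout stores them as D3D12_STATIC_SAMPLER_DESCs. When the PipelineLayout (root signature) is created, the static samplers of all its bind group layouts are combined and added to the root signature, each assigned the register space of its bind group. The StaticSamplers feature is now enabled on D3D12.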

Change-Id: I60ac3754b2e24f452dc667ab3e35d6b591f3fffe
Bug: 42241433
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/191381
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Srijan Dhungana <srijan.dhungana6@gmail.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/native/d3d12/BindGroupD3D12.cpp b/src/dawn/native/d3d12/BindGroupD3D12.cpp
index 53cd5bd..9a7d3dd 100644
--- a/src/dawn/native/d3d12/BindGroupD3D12.cpp
+++ b/src/dawn/native/d3d12/BindGroupD3D12.cpp
@@ -195,10 +195,7 @@
                 }
             },
             [](const StaticSamplerBindingInfo&) {
-                // Static samplers are handled in the frontend.
-                // TODO(crbug.com/dawn/2483): Implement static samplers in the
-                // D3D12 backend.
-                DAWN_UNREACHABLE();
+                // Static samplers are already initialized in the pipeline layout.
             },
             // No-op as samplers will be later initialized by CreateSamplers().
             [](const SamplerBindingInfo&) {},
diff --git a/src/dawn/native/d3d12/BindGroupLayoutD3D12.cpp b/src/dawn/native/d3d12/BindGroupLayoutD3D12.cpp
index ddaf8fe..2cf9ce6 100644
--- a/src/dawn/native/d3d12/BindGroupLayoutD3D12.cpp
+++ b/src/dawn/native/d3d12/BindGroupLayoutD3D12.cpp
@@ -32,8 +32,10 @@
 #include "dawn/common/BitSetIterator.h"
 #include "dawn/common/MatchVariant.h"
 #include "dawn/native/d3d12/DeviceD3D12.h"
+#include "dawn/native/d3d12/SamplerD3D12.h"
 #include "dawn/native/d3d12/SamplerHeapCacheD3D12.h"
 #include "dawn/native/d3d12/StagingDescriptorAllocatorD3D12.h"
+#include "dawn/native/d3d12/UtilsD3D12.h"
 
 namespace dawn::native::d3d12 {
 namespace {
@@ -54,10 +56,6 @@
             }
         },
         [](const StaticSamplerBindingInfo&) -> D3D12_DESCRIPTOR_RANGE_TYPE {
-            // Static samplers are handled in the frontend.
-            // TODO(crbug.com/dawn/2483): Implement static samplers in the
-            // D3D12 backend.
-            DAWN_UNREACHABLE();
             return D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER;
         },
         [](const SamplerBindingInfo&) -> D3D12_DESCRIPTOR_RANGE_TYPE {
@@ -104,6 +102,34 @@
             WGPUBindingInfoToDescriptorRangeType(bindingInfo);
         mShaderRegisters[bindingIndex] = uint32_t(bindingInfo.binding);
 
+        // Static samplers aren't stored in the descriptor heap. Handle them separately: each one
+        // is recorded here and later embedded in the root signature by the pipeline layout.
+        if (const auto* staticSamplerBindingInfo =
+                std::get_if<StaticSamplerBindingInfo>(&bindingInfo.bindingLayout)) {
+            Sampler* sampler = ToBackend(staticSamplerBindingInfo->sampler.Get());
+
+            const D3D12_SAMPLER_DESC& desc = sampler->GetSamplerDescriptor();
+            D3D12_STATIC_SAMPLER_DESC staticSamplerDesc = {};
+            staticSamplerDesc.ShaderRegister = GetShaderRegister(bindingIndex);
+            // The register space is the bind group index, which is only known once this layout
+            // is used in a pipeline layout, so store a placeholder that gets replaced there.
+            staticSamplerDesc.RegisterSpace = kRegisterSpacePlaceholder;
+            staticSamplerDesc.ShaderVisibility = ShaderVisibilityType(bindingInfo.visibility);
+            staticSamplerDesc.AddressU = desc.AddressU;
+            staticSamplerDesc.AddressV = desc.AddressV;
+            staticSamplerDesc.AddressW = desc.AddressW;
+            staticSamplerDesc.Filter = desc.Filter;
+            staticSamplerDesc.MinLOD = desc.MinLOD;
+            staticSamplerDesc.MaxLOD = desc.MaxLOD;
+            staticSamplerDesc.MipLODBias = desc.MipLODBias;
+            staticSamplerDesc.MaxAnisotropy = desc.MaxAnisotropy;
+            staticSamplerDesc.ComparisonFunc = desc.ComparisonFunc;
+
+            mStaticSamplers.push_back(staticSamplerDesc);
+
+            continue;
+        }
+
         // For dynamic resources, Dawn uses root descriptor in D3D12 backend. So there is no
         // need to allocate the descriptor from descriptor heap or create descriptor ranges.
         if (bindingIndex < GetDynamicBufferCount()) {
@@ -132,9 +157,7 @@
         range.Flags = MatchVariant(
             bindingInfo.bindingLayout,
             [](const StaticSamplerBindingInfo&) -> D3D12_DESCRIPTOR_RANGE_FLAGS {
-                // Static samplers are handled in the frontend.
-                // TODO(crbug.com/dawn/2483): Implement static samplers in the
-                // D3D12 backend.
+                // Static samplers should already be handled. This should never be reached.
                 DAWN_UNREACHABLE();
                 return D3D12_DESCRIPTOR_RANGE_FLAG_NONE;
             },
@@ -252,4 +275,8 @@
     return mSamplerDescriptorRanges;
 }
 
+const std::vector<D3D12_STATIC_SAMPLER_DESC>& BindGroupLayout::GetStaticSamplers() const {
+    return mStaticSamplers;
+}
+
 }  // namespace dawn::native::d3d12
diff --git a/src/dawn/native/d3d12/BindGroupLayoutD3D12.h b/src/dawn/native/d3d12/BindGroupLayoutD3D12.h
index 5c8bc81..5912d86 100644
--- a/src/dawn/native/d3d12/BindGroupLayoutD3D12.h
+++ b/src/dawn/native/d3d12/BindGroupLayoutD3D12.h
@@ -73,6 +73,10 @@
 
     const std::vector<D3D12_DESCRIPTOR_RANGE1>& GetCbvUavSrvDescriptorRanges() const;
     const std::vector<D3D12_DESCRIPTOR_RANGE1>& GetSamplerDescriptorRanges() const;
+    // Static sampler descriptors for this layout's static sampler bindings. Their
+    // RegisterSpace is kRegisterSpacePlaceholder; callers must set it to the bind group index
+    // before using them in a root signature.
+    const std::vector<D3D12_STATIC_SAMPLER_DESC>& GetStaticSamplers() const;
 
   private:
     BindGroupLayout(Device* device, const BindGroupLayoutDescriptor* descriptor);
@@ -95,6 +96,8 @@
     std::vector<D3D12_DESCRIPTOR_RANGE1> mCbvUavSrvDescriptorRanges;
     std::vector<D3D12_DESCRIPTOR_RANGE1> mSamplerDescriptorRanges;
 
+    std::vector<D3D12_STATIC_SAMPLER_DESC> mStaticSamplers;
+
     MutexProtected<SlabAllocator<BindGroup>> mBindGroupAllocator;
 
     // TODO(https://crbug.com/dawn/2361): Rewrite those members with raw_ptr<T>.
diff --git a/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp b/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
index 95d0499..404f137 100644
--- a/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
@@ -145,6 +145,7 @@
     EnableFeature(Feature::R8UnormStorage);
     EnableFeature(Feature::SharedBufferMemoryD3D12Resource);
     EnableFeature(Feature::ShaderModuleCompilationOptions);
+    EnableFeature(Feature::StaticSamplers);
 
     if (AreTimestampQueriesSupported()) {
         EnableFeature(Feature::TimestampQuery);
diff --git a/src/dawn/native/d3d12/PipelineLayoutD3D12.cpp b/src/dawn/native/d3d12/PipelineLayoutD3D12.cpp
index b914f15..4d47ab8 100644
--- a/src/dawn/native/d3d12/PipelineLayoutD3D12.cpp
+++ b/src/dawn/native/d3d12/PipelineLayoutD3D12.cpp
@@ -36,6 +36,7 @@
 #include "dawn/native/d3d12/BindGroupLayoutD3D12.h"
 #include "dawn/native/d3d12/DeviceD3D12.h"
 #include "dawn/native/d3d12/PlatformFunctionsD3D12.h"
+#include "dawn/native/d3d12/UtilsD3D12.h"
 
 using Microsoft::WRL::ComPtr;
 
@@ -55,21 +56,6 @@
 static constexpr uint32_t kInvalidDynamicStorageBufferLengthsParameterIndex =
     std::numeric_limits<uint32_t>::max();
 
-D3D12_SHADER_VISIBILITY ShaderVisibilityType(wgpu::ShaderStage visibility) {
-    DAWN_ASSERT(visibility != wgpu::ShaderStage::None);
-
-    if (visibility == wgpu::ShaderStage::Vertex) {
-        return D3D12_SHADER_VISIBILITY_VERTEX;
-    }
-
-    if (visibility == wgpu::ShaderStage::Fragment) {
-        return D3D12_SHADER_VISIBILITY_PIXEL;
-    }
-
-    // For compute or any two combination of stages, visibility must be ALL
-    return D3D12_SHADER_VISIBILITY_ALL;
-}
-
 D3D12_ROOT_PARAMETER_TYPE RootParameterType(wgpu::BufferBindingType type) {
     switch (type) {
         case wgpu::BufferBindingType::Uniform:
@@ -167,16 +153,20 @@
     // Parameters are D3D12_ROOT_PARAMETER_TYPE which is either a root table, constant, or
     // descriptor.
     std::vector<D3D12_ROOT_PARAMETER1> rootParameters;
+    std::vector<D3D12_STATIC_SAMPLER_DESC> staticSamplers;
 
     size_t rangesCount = 0;
+    size_t staticSamplerCount = 0;
     for (BindGroupIndex group : IterateBitSet(GetBindGroupLayoutsMask())) {
         const BindGroupLayout* bindGroupLayout = ToBackend(GetBindGroupLayout(group));
         rangesCount += bindGroupLayout->GetCbvUavSrvDescriptorRanges().size() +
                        bindGroupLayout->GetSamplerDescriptorRanges().size();
+        staticSamplerCount += bindGroupLayout->GetStaticSamplerCount();
     }
 
     // We are taking pointers to `ranges`, so we cannot let it resize while we're pushing to it.
     std::vector<D3D12_DESCRIPTOR_RANGE1> ranges(rangesCount);
+    staticSamplers.reserve(staticSamplerCount);
 
     uint32_t rangeIndex = 0;
 
@@ -218,6 +208,14 @@
             mSamplerRootParameterInfo[group] = rootParameters.size() - 1;
         }
 
+        // Combine the static samplers from all of the bind group layouts into one vector.
+        for (const auto& samplerDesc : bindGroupLayout->GetStaticSamplers()) {
+            auto& newSampler = staticSamplers.emplace_back(samplerDesc);
+            // Bind group layouts store a placeholder register space because they don't know
+            // which group they are bound to; use this group's index.
+            newSampler.RegisterSpace = static_cast<uint32_t>(group);
+        }
+
         // Init root descriptors in root signatures for dynamic buffer bindings.
         // These are packed at the beginning of the layout binding info.
         mDynamicRootParameterIndices[group].resize(bindGroupLayout->GetDynamicBufferCount());
@@ -335,8 +331,8 @@
     versionedRootSignatureDescriptor.Version = D3D_ROOT_SIGNATURE_VERSION_1_1;
     versionedRootSignatureDescriptor.Desc_1_1.NumParameters = rootParameters.size();
     versionedRootSignatureDescriptor.Desc_1_1.pParameters = rootParameters.data();
-    versionedRootSignatureDescriptor.Desc_1_1.NumStaticSamplers = 0;
-    versionedRootSignatureDescriptor.Desc_1_1.pStaticSamplers = nullptr;
+    versionedRootSignatureDescriptor.Desc_1_1.NumStaticSamplers = staticSamplers.size();
+    versionedRootSignatureDescriptor.Desc_1_1.pStaticSamplers = staticSamplers.data();
     versionedRootSignatureDescriptor.Desc_1_1.Flags =
         D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT;
 
diff --git a/src/dawn/native/d3d12/UtilsD3D12.cpp b/src/dawn/native/d3d12/UtilsD3D12.cpp
index ea12054..613bdca 100644
--- a/src/dawn/native/d3d12/UtilsD3D12.cpp
+++ b/src/dawn/native/d3d12/UtilsD3D12.cpp
@@ -118,6 +118,24 @@
     }
 }
 
+// Maps a WebGPU shader-stage visibility mask to a D3D12_SHADER_VISIBILITY. `visibility` must
+// not be wgpu::ShaderStage::None.
+D3D12_SHADER_VISIBILITY ShaderVisibilityType(wgpu::ShaderStage visibility) {
+    DAWN_ASSERT(visibility != wgpu::ShaderStage::None);
+
+    if (visibility == wgpu::ShaderStage::Vertex) {
+        return D3D12_SHADER_VISIBILITY_VERTEX;
+    }
+
+    if (visibility == wgpu::ShaderStage::Fragment) {
+        return D3D12_SHADER_VISIBILITY_PIXEL;
+    }
+
+    // Only a lone Vertex or Fragment stage gets a stage-specific visibility. Compute, or any
+    // combination of two or more stages, must be visible to ALL stages.
+    return D3D12_SHADER_VISIBILITY_ALL;
+}
+
 D3D12_TEXTURE_COPY_LOCATION ComputeTextureCopyLocationForTexture(const Texture* texture,
                                                                  uint32_t level,
                                                                  uint32_t layer,
diff --git a/src/dawn/native/d3d12/UtilsD3D12.h b/src/dawn/native/d3d12/UtilsD3D12.h
index d5f9108..e73c4a7 100644
--- a/src/dawn/native/d3d12/UtilsD3D12.h
+++ b/src/dawn/native/d3d12/UtilsD3D12.h
@@ -42,6 +42,8 @@
 
 D3D12_COMPARISON_FUNC ToD3D12ComparisonFunc(wgpu::CompareFunction func);
 
+D3D12_SHADER_VISIBILITY ShaderVisibilityType(wgpu::ShaderStage visibility);
+
 D3D12_TEXTURE_COPY_LOCATION ComputeTextureCopyLocationForTexture(const Texture* texture,
                                                                  uint32_t level,
                                                                  uint32_t layer,