[dawn][d3d12] Support custom cross-adapter heaps in SBM

A CUSTOM shared cross-adapter heap is equivalent to a DEFAULT heap type with no CPU access.

Bug: 441704688
Change-Id: I4b839727533bfbbed266214e5bfc7f38dd947156
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/308675
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Bernhart, Bryan <bryan.bernhart@intel.com>
diff --git a/src/dawn/native/d3d12/SharedBufferMemoryD3D12.cpp b/src/dawn/native/d3d12/SharedBufferMemoryD3D12.cpp
index cacf7ec..dea1629 100644
--- a/src/dawn/native/d3d12/SharedBufferMemoryD3D12.cpp
+++ b/src/dawn/native/d3d12/SharedBufferMemoryD3D12.cpp
@@ -53,6 +53,7 @@
 };
 
 ResultOrError<HeapAccessType> MapToHeapAccessType(const D3D12_HEAP_PROPERTIES& heapProperties,
+                                                  const D3D12_HEAP_FLAGS& heapFlags,
                                                   const Device* device) {
     switch (heapProperties.Type) {
         case D3D12_HEAP_TYPE_UPLOAD:
@@ -62,6 +63,12 @@
         case D3D12_HEAP_TYPE_DEFAULT:
             return HeapAccessType::GPUQueueAccessible;
         case D3D12_HEAP_TYPE_CUSTOM:
+            if (heapFlags & D3D12_HEAP_FLAG_SHARED_CROSS_ADAPTER) {
+                // A CUSTOM shared cross-adapter heap is equivalent to a DEFAULT heap.
+                // https://learn.microsoft.com/en-us/windows/win32/direct3d12/shared-heaps
+                return HeapAccessType::GPUQueueAccessible;
+            }
+
             if (device->GetDeviceInfo().isUMA) {
                 // On UMA systems, all heaps are always GPU accessible.
                 return HeapAccessType::GPUQueueAccessible;
@@ -93,10 +100,11 @@
 ResultOrError<SharedBufferMemoryProperties> GetSharedBufferMemoryProperties(
     Device* device,
     D3D12_HEAP_PROPERTIES heapProperties,
+    D3D12_HEAP_FLAGS heapFlags,
     bool allowUAV,
     uint64_t size) {
     HeapAccessType heapType;
-    DAWN_TRY_ASSIGN(heapType, MapToHeapAccessType(heapProperties, device));
+    DAWN_TRY_ASSIGN(heapType, MapToHeapAccessType(heapProperties, heapFlags, device));
 
     wgpu::BufferUsage usages = wgpu::BufferUsage::None;
 
@@ -188,8 +196,8 @@
     bool allowUAV = desc.Flags & D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;
 
     SharedBufferMemoryProperties properties;
-    DAWN_TRY_ASSIGN(properties,
-                    GetSharedBufferMemoryProperties(device, heapProperties, allowUAV, desc.Width));
+    DAWN_TRY_ASSIGN(properties, GetSharedBufferMemoryProperties(device, heapProperties, heapFlags,
+                                                                allowUAV, desc.Width));
 
     auto result =
         AcquireRef(new SharedBufferMemory(device, label, properties, std::move(d3d12Resource)));
@@ -220,9 +228,10 @@
 
     D3D12_HEAP_DESC heapDesc = d3d12Heap->GetDesc();
     D3D12_HEAP_PROPERTIES heapProperties = heapDesc.Properties;
+    D3D12_HEAP_FLAGS heapFlags = heapDesc.Flags;
     SharedBufferMemoryProperties properties;
-    DAWN_TRY_ASSIGN(properties, GetSharedBufferMemoryProperties(device, heapProperties, true,
-                                                                descriptor->size));
+    DAWN_TRY_ASSIGN(properties, GetSharedBufferMemoryProperties(device, heapProperties, heapFlags,
+                                                                true, descriptor->size));
 
     D3D12_RESOURCE_DESC resourceDescriptor;
     resourceDescriptor.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
diff --git a/src/dawn/tests/white_box/SharedBufferMemoryTests_win.cpp b/src/dawn/tests/white_box/SharedBufferMemoryTests_win.cpp
index b1317c8..086360c 100644
--- a/src/dawn/tests/white_box/SharedBufferMemoryTests_win.cpp
+++ b/src/dawn/tests/white_box/SharedBufferMemoryTests_win.cpp
@@ -287,6 +287,52 @@
     ASSERT_TRUE(sharedBufferMemory.CreateBuffer().Get());
 }
 
+// Validate that importing an ID3D12Resource placed on a CUSTOM shared cross-adapter heap (which
+// is treated as equivalent to a DEFAULT heap) succeeds and that a buffer can be created from it.
+TEST_P(SharedBufferMemoryExistingD3D12ResourceTests, CustomCrossAdapterHeapImport) {
+    ComPtr<ID3D12Device> d3d12Device =
+        static_cast<ExistingD3D12ResourceBackend*>(GetParam().mBackend)
+            ->CreateD3D12Device(device, false);
+
+    D3D12_HEAP_PROPERTIES heapProperties = {
+        D3D12_HEAP_TYPE_CUSTOM, D3D12_CPU_PAGE_PROPERTY_NOT_AVAILABLE, D3D12_MEMORY_POOL_L0, 0, 0};
+
+    D3D12_HEAP_DESC heapDesc = {kBufferSize, heapProperties,
+                                D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT,
+                                D3D12_HEAP_FLAG_SHARED | D3D12_HEAP_FLAG_SHARED_CROSS_ADAPTER};
+    ComPtr<ID3D12Heap> heap;
+    HRESULT hr = d3d12Device->CreateHeap(&heapDesc, IID_PPV_ARGS(&heap));
+    ASSERT_EQ(hr, S_OK);  // The heap must exist before a resource can be placed in it.
+
+    D3D12_RESOURCE_DESC resourceDesc = {};
+    resourceDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
+    resourceDesc.Alignment = 0;
+    resourceDesc.Width = kBufferSize;
+    resourceDesc.Height = 1;
+    resourceDesc.DepthOrArraySize = 1;
+    resourceDesc.MipLevels = 1;
+    resourceDesc.Format = DXGI_FORMAT_UNKNOWN;
+    resourceDesc.SampleDesc.Count = 1;
+    resourceDesc.SampleDesc.Quality = 0;
+    resourceDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
+    resourceDesc.Flags =
+        D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS | D3D12_RESOURCE_FLAG_ALLOW_CROSS_ADAPTER;
+
+    ComPtr<ID3D12Resource> d3d12Resource;
+    hr =
+        d3d12Device->CreatePlacedResource(heap.Get(), 0, &resourceDesc, D3D12_RESOURCE_STATE_COMMON,
+                                          nullptr, IID_PPV_ARGS(&d3d12Resource));
+    ASSERT_EQ(hr, S_OK);
+
+    wgpu::SharedBufferMemoryDescriptor desc;
+    native::d3d12::SharedBufferMemoryD3D12ResourceDescriptor sharedD3d12ResourceDesc;
+    sharedD3d12ResourceDesc.resource = d3d12Resource.Get();
+    desc.nextInChain = &sharedD3d12ResourceDesc;
+
+    wgpu::SharedBufferMemory sharedBufferMemory = device.ImportSharedBufferMemory(&desc);
+    ASSERT_TRUE(sharedBufferMemory.CreateBuffer().Get());  // Expect a non-null buffer handle.
+}
+
 class D3D12SharedMemoryFileHandleBackend : public SharedBufferMemoryTestBackend {
   public:
     static Backend GetInstance() {