D3D12: Implement `GetAllocatorMemoryInfo()`

This patch implements the memory allocation tracking on D3D12 backend
just like what memory tracking is done on Vulkan backend.

Bug: 407730048
Change-Id: I94130379c1e547a746e8c057f5cfac9ee077f208
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/308395
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn/native/PooledResourceMemoryAllocator.cpp b/src/dawn/native/PooledResourceMemoryAllocator.cpp
index 629d18c..815e638 100644
--- a/src/dawn/native/PooledResourceMemoryAllocator.cpp
+++ b/src/dawn/native/PooledResourceMemoryAllocator.cpp
@@ -75,4 +75,21 @@
 uint64_t PooledResourceMemoryAllocator::GetPoolSizeForTesting() const {
     return mPool.size();
 }
+
+void AllocationSizeTracker::Increment(uint64_t incrementSize) {
+    mTotalSize += incrementSize;
+}
+
+void AllocationSizeTracker::Decrement(ExecutionSerial currentSerial, uint64_t decrementSize) {
+    DAWN_ASSERT(mTotalSize >= decrementSize);
+    mMemoryToDecrement.Enqueue(decrementSize, currentSerial);
+}
+
+void AllocationSizeTracker::Tick(ExecutionSerial completedSerial) {
+    for (uint64_t size : mMemoryToDecrement.IterateUpTo(completedSerial)) {
+        DAWN_ASSERT(mTotalSize >= size);
+        mTotalSize -= size;
+    }
+    mMemoryToDecrement.ClearUpTo(completedSerial);
+}
 }  // namespace dawn::native
diff --git a/src/dawn/native/PooledResourceMemoryAllocator.h b/src/dawn/native/PooledResourceMemoryAllocator.h
index 48f0bb4..a93832d 100644
--- a/src/dawn/native/PooledResourceMemoryAllocator.h
+++ b/src/dawn/native/PooledResourceMemoryAllocator.h
@@ -32,6 +32,7 @@
 #include <memory>
 
 #include "dawn/common/SerialQueue.h"
+#include "dawn/native/IntegerTypes.h"
 #include "dawn/native/ResourceHeapAllocator.h"
 #include "partition_alloc/pointers/raw_ptr.h"
 
@@ -62,6 +63,24 @@
     std::deque<std::unique_ptr<ResourceHeapBase>> mPool;
 };
 
+// Wrapper for tracking the allocation sizes to be decremented up to a completed ExecutionSerial
+// and reporting total allocation/used sizes.
+class AllocationSizeTracker {
+  public:
+    // Increment the total size for tracking.
+    void Increment(uint64_t incrementSize);
+    // Track the size to be decremented on Tick.
+    void Decrement(ExecutionSerial currentSerial, uint64_t decrementSize);
+    // Update the total size after completed serials.
+    void Tick(ExecutionSerial completedSerial);
+
+    uint64_t GetSize() const { return mTotalSize; }
+
+  private:
+    SerialQueue<ExecutionSerial, uint64_t> mMemoryToDecrement;
+    uint64_t mTotalSize = 0;
+};
+
 }  // namespace dawn::native
 
 #endif  // SRC_DAWN_NATIVE_POOLEDRESOURCEMEMORYALLOCATOR_H_
diff --git a/src/dawn/native/d3d12/DeviceD3D12.cpp b/src/dawn/native/d3d12/DeviceD3D12.cpp
index 06594a7..89f5842 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/DeviceD3D12.cpp
@@ -901,4 +901,13 @@
     return mDxcShaderProfiles;
 }
 
+AllocatorMemoryInfo Device::GetAllocatorMemoryInfo() const {
+    DAWN_ASSERT(IsLockedByCurrentThreadIfNeeded());
+    AllocatorMemoryInfo info = {};
+    info.totalAllocatedMemory = (*mResourceAllocatorManager)->GetTotalAllocatedMemory();
+    info.totalUsedMemory = (*mResourceAllocatorManager)->GetTotalUsedMemory();
+    // D3D12 has no lazy memory concept, leave lazy fields as zero.
+    return info;
+}
+
 }  // namespace dawn::native::d3d12
diff --git a/src/dawn/native/d3d12/DeviceD3D12.h b/src/dawn/native/d3d12/DeviceD3D12.h
index 4babdcd..04d5003 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.h
+++ b/src/dawn/native/d3d12/DeviceD3D12.h
@@ -164,6 +164,8 @@
 
     uint64_t GetBufferCopyOffsetAlignmentForDepthStencil() const override;
 
+    AllocatorMemoryInfo GetAllocatorMemoryInfo() const override;
+
     // Dawn APIs
     void SetLabelImpl() override;
 
diff --git a/src/dawn/native/d3d12/HeapAllocatorD3D12.cpp b/src/dawn/native/d3d12/HeapAllocatorD3D12.cpp
index b151909..e76a972 100644
--- a/src/dawn/native/d3d12/HeapAllocatorD3D12.cpp
+++ b/src/dawn/native/d3d12/HeapAllocatorD3D12.cpp
@@ -29,6 +29,7 @@
 
 #include <utility>
 
+#include "dawn/native/Queue.h"
 #include "dawn/native/d3d/D3DError.h"
 #include "dawn/native/d3d12/DeviceD3D12.h"
 #include "dawn/native/d3d12/HeapD3D12.h"
@@ -40,11 +41,13 @@
 HeapAllocator::HeapAllocator(Device* device,
                              ResourceHeapKind resourceHeapKind,
                              D3D12_HEAP_FLAGS heapFlags,
-                             MemorySegment memorySegment)
+                             MemorySegment memorySegment,
+                             AllocationSizeTracker* allocationMemoryTracker)
     : mDevice(device),
       mResourceHeapKind(resourceHeapKind),
       mHeapFlags(heapFlags),
-      mMemorySegment(memorySegment) {}
+      mMemorySegment(memorySegment),
+      mAllocationMemoryTracker(allocationMemoryTracker) {}
 
 ResultOrError<std::unique_ptr<ResourceHeapBase>> HeapAllocator::AllocateResourceHeap(
     uint64_t size) {
@@ -67,6 +70,7 @@
 
     std::unique_ptr<ResourceHeapBase> heapBase =
         std::make_unique<Heap>(std::move(d3d12Heap), mMemorySegment, size);
+    mAllocationMemoryTracker->Increment(size);
 
     // Calling CreateHeap implicitly calls MakeResident on the new heap. We must track this to
     // avoid calling MakeResident a second time.
@@ -75,7 +79,10 @@
 }
 
 void HeapAllocator::DeallocateResourceHeap(std::unique_ptr<ResourceHeapBase> heap) {
-    mDevice->ReferenceUntilUnused(static_cast<Heap*>(heap.get())->GetD3D12Heap());
+    Heap* d3d12Heap = static_cast<Heap*>(heap.get());
+    mDevice->ReferenceUntilUnused(d3d12Heap->GetD3D12Heap());
+    mAllocationMemoryTracker->Decrement(mDevice->GetQueue()->GetPendingCommandSerial(),
+                                        d3d12Heap->GetSize());
 }
 
 }  // namespace dawn::native::d3d12
diff --git a/src/dawn/native/d3d12/HeapAllocatorD3D12.h b/src/dawn/native/d3d12/HeapAllocatorD3D12.h
index 72789a43..3eaee93 100644
--- a/src/dawn/native/d3d12/HeapAllocatorD3D12.h
+++ b/src/dawn/native/d3d12/HeapAllocatorD3D12.h
@@ -31,6 +31,7 @@
 #include <memory>
 
 #include "dawn/native/D3D12Backend.h"
+#include "dawn/native/PooledResourceMemoryAllocator.h"
 #include "dawn/native/ResourceHeapAllocator.h"
 #include "dawn/native/d3d12/ResourceAllocatorManagerD3D12.h"
 #include "dawn/native/d3d12/d3d12_platform.h"
@@ -46,7 +47,8 @@
     HeapAllocator(Device* device,
                   ResourceHeapKind resourceHeapKind,
                   D3D12_HEAP_FLAGS heapFlags,
-                  MemorySegment memorySegment);
+                  MemorySegment memorySegment,
+                  AllocationSizeTracker* allocationMemoryTracker);
     ~HeapAllocator() override = default;
 
     ResultOrError<std::unique_ptr<ResourceHeapBase>> AllocateResourceHeap(uint64_t size) override;
@@ -57,6 +59,8 @@
     ResourceHeapKind mResourceHeapKind;
     D3D12_HEAP_FLAGS mHeapFlags;
     MemorySegment mMemorySegment;
+    // Owned by ResourceAllocatorManager, which creates and outlives this HeapAllocator.
+    raw_ptr<AllocationSizeTracker> mAllocationMemoryTracker;
 };
 
 }  // namespace dawn::native::d3d12
diff --git a/src/dawn/native/d3d12/ResourceAllocatorManagerD3D12.cpp b/src/dawn/native/d3d12/ResourceAllocatorManagerD3D12.cpp
index 7b61a99..fc8c9c7 100644
--- a/src/dawn/native/d3d12/ResourceAllocatorManagerD3D12.cpp
+++ b/src/dawn/native/d3d12/ResourceAllocatorManagerD3D12.cpp
@@ -305,7 +305,8 @@
         const ResourceHeapKind resourceHeapKind = static_cast<ResourceHeapKind>(i);
         D3D12_HEAP_FLAGS heapFlags = GetD3D12HeapFlags(resourceHeapKind) | createNotZeroedHeapFlag;
         mHeapAllocators[i] = std::make_unique<HeapAllocator>(
-            mDevice, resourceHeapKind, heapFlags, GetMemorySegment(mDevice, resourceHeapKind));
+            mDevice, resourceHeapKind, heapFlags, GetMemorySegment(mDevice, resourceHeapKind),
+            &mAllocatedMemory);
         mPooledHeapAllocators[i] =
             std::make_unique<PooledResourceMemoryAllocator>(mHeapAllocators[i].get());
         mSubAllocatedResourceAllocators[i] = std::make_unique<BuddyMemoryAllocator>(
@@ -403,6 +404,9 @@
     }
     mAllocationsToDelete.ClearUpTo(completedSerial);
     mHeapsToDelete.ClearUpTo(completedSerial);
+
+    mAllocatedMemory.Tick(completedSerial);
+    mUsedMemory.Tick(completedSerial);
 }
 
 void ResourceAllocatorManager::DeallocateMemory(ResourceHeapAllocation& allocation) {
@@ -420,6 +424,14 @@
     if (allocation.GetInfo().mMethod == AllocationMethod::kDirect) {
         mHeapsToDelete.Enqueue(std::unique_ptr<ResourceHeapBase>(allocation.GetResourceHeap()),
                                mDevice->GetQueue()->GetPendingCommandSerial());
+
+        mUsedMemory.Decrement(mDevice->GetQueue()->GetPendingCommandSerial(),
+                              allocation.GetInfo().mRequestedSize);
+        mAllocatedMemory.Decrement(mDevice->GetQueue()->GetPendingCommandSerial(),
+                                   allocation.GetInfo().mRequestedSize);
+    } else if (allocation.GetInfo().mMethod == AllocationMethod::kSubAllocated) {
+        mUsedMemory.Decrement(mDevice->GetQueue()->GetPendingCommandSerial(),
+                              allocation.GetInfo().mRequestedSize);
     }
 
     // Invalidate the allocation immediately in case one accidentally
@@ -502,6 +514,8 @@
             optimizedClearValue, IID_PPV_ARGS(&placedResource)),
         "ID3D12Device::CreatePlacedResource"));
 
+    mUsedMemory.Increment(resourceInfo.SizeInBytes);
+
     // After CreatePlacedResource has finished, the heap can be unlocked from residency. This
     // will insert it into the residency LRU.
     mDevice->GetResidencyManager()->UnlockAllocation(heap);
@@ -556,6 +570,9 @@
             optimizedClearValue, IID_PPV_ARGS(&committedResource)),
         "ID3D12Device::CreateCommittedResource"));
 
+    mAllocatedMemory.Increment(resourceInfo.SizeInBytes);
+    mUsedMemory.Increment(resourceInfo.SizeInBytes);
+
     // When using CreateCommittedResource, D3D12 creates an implicit heap that contains the
     // resource allocation. Because Dawn's memory residency management occurs at the resource
     // heap granularity, every directly allocated ResourceHeapAllocation also stores a Heap
@@ -583,4 +600,12 @@
     }
 }
 
+uint64_t ResourceAllocatorManager::GetTotalAllocatedMemory() const {
+    return mAllocatedMemory.GetSize();
+}
+
+uint64_t ResourceAllocatorManager::GetTotalUsedMemory() const {
+    return mUsedMemory.GetSize();
+}
+
 }  // namespace dawn::native::d3d12
diff --git a/src/dawn/native/d3d12/ResourceAllocatorManagerD3D12.h b/src/dawn/native/d3d12/ResourceAllocatorManagerD3D12.h
index 9f0678a..419f6e2 100644
--- a/src/dawn/native/d3d12/ResourceAllocatorManagerD3D12.h
+++ b/src/dawn/native/d3d12/ResourceAllocatorManagerD3D12.h
@@ -90,6 +90,9 @@
 
     void Tick(ExecutionSerial lastCompletedSerial);
 
+    uint64_t GetTotalAllocatedMemory() const;
+    uint64_t GetTotalUsedMemory() const;
+
   private:
     void FreeSubAllocatedMemory(ResourceHeapAllocation& allocation);
 
@@ -121,6 +124,9 @@
 
     SerialQueue<ExecutionSerial, ResourceHeapAllocation> mAllocationsToDelete;
     SerialQueue<ExecutionSerial, std::unique_ptr<ResourceHeapBase>> mHeapsToDelete;
+
+    AllocationSizeTracker mAllocatedMemory;
+    AllocationSizeTracker mUsedMemory;
 };
 
 }  // namespace dawn::native::d3d12
diff --git a/src/dawn/native/vulkan/ResourceMemoryAllocatorVk.cpp b/src/dawn/native/vulkan/ResourceMemoryAllocatorVk.cpp
index f51e393..2d9e1a9 100644
--- a/src/dawn/native/vulkan/ResourceMemoryAllocatorVk.cpp
+++ b/src/dawn/native/vulkan/ResourceMemoryAllocatorVk.cpp
@@ -134,28 +134,6 @@
     BuddyMemoryAllocator mBuddySystem;
 };
 
-void ResourceMemoryAllocator::AllocationSizeTracker::Increment(VkDeviceSize incrementSize) {
-    mTotalSize += incrementSize;
-}
-
-void ResourceMemoryAllocator::AllocationSizeTracker::Decrement(ExecutionSerial currentSerial,
-                                                               VkDeviceSize decrementSize) {
-    DAWN_ASSERT(mTotalSize >= decrementSize);
-    mMemoryToDecrement[currentSerial] += decrementSize;
-}
-
-void ResourceMemoryAllocator::AllocationSizeTracker::Tick(ExecutionSerial completedSerial) {
-    auto it = mMemoryToDecrement.begin();
-    while (it != mMemoryToDecrement.end() && it->first <= completedSerial) {
-        // Update tracking for allocation/used memory that will be deallocated.
-        DAWN_ASSERT(mTotalSize >= it->second);
-        mTotalSize -= it->second;
-        it++;
-    }
-    // Erase the map serials up to the completed serial.
-    mMemoryToDecrement.erase(mMemoryToDecrement.begin(), it);
-}
-
 VkDeviceSize ResourceMemoryAllocator::GetHeapBlockSize(const DawnDeviceAllocatorControl* control) {
     static constexpr VkDeviceSize kDefaultHeapBlockSize = 8ull * 1024ull * 1024ull;  // 8MiB
     VkDeviceSize heapBlockSize = kDefaultHeapBlockSize;
@@ -367,19 +345,19 @@
 }
 
 uint64_t ResourceMemoryAllocator::GetTotalUsedMemory() const {
-    return mUsedMemory.Size();
+    return mUsedMemory.GetSize();
 }
 
 uint64_t ResourceMemoryAllocator::GetTotalAllocatedMemory() const {
-    return mAllocatedMemory.Size();
+    return mAllocatedMemory.GetSize();
 }
 
 uint64_t ResourceMemoryAllocator::GetTotalLazyAllocatedMemory() const {
-    return mLazyAllocatedMemory.Size();
+    return mLazyAllocatedMemory.GetSize();
 }
 
 uint64_t ResourceMemoryAllocator::GetTotalLazyUsedMemory() const {
-    return mLazyUsedMemory.Size();
+    return mLazyUsedMemory.GetSize();
 }
 
 }  // namespace dawn::native::vulkan
diff --git a/src/dawn/native/vulkan/ResourceMemoryAllocatorVk.h b/src/dawn/native/vulkan/ResourceMemoryAllocatorVk.h
index 87616b7..e2a1e25 100644
--- a/src/dawn/native/vulkan/ResourceMemoryAllocatorVk.h
+++ b/src/dawn/native/vulkan/ResourceMemoryAllocatorVk.h
@@ -28,7 +28,6 @@
 #ifndef SRC_DAWN_NATIVE_VULKAN_RESOURCEMEMORYALLOCATORVK_H_
 #define SRC_DAWN_NATIVE_VULKAN_RESOURCEMEMORYALLOCATORVK_H_
 
-#include <map>
 #include <memory>
 #include <vector>
 
@@ -83,24 +82,6 @@
     void DeallocateResourceHeap(ResourceHeap* heap, bool isLazyMemoryType);
 
   private:
-    // Wrapper for tracking the allocation sizes to be decremented up to a completed ExecutionSerial
-    // and reporting total allocation/used sizes.
-    class AllocationSizeTracker {
-      public:
-        // Increment the total size for tracking.
-        void Increment(VkDeviceSize incrementSize);
-        // Track the size to be decremented on Tick.
-        void Decrement(ExecutionSerial currentSerial, VkDeviceSize decrementSize);
-        // Update the total size after completed serials.
-        void Tick(ExecutionSerial completedSerial);
-
-        VkDeviceSize Size() const { return mTotalSize; }
-
-      private:
-        std::map<ExecutionSerial, VkDeviceSize> mMemoryToDecrement;
-        VkDeviceSize mTotalSize = 0;
-    };
-
     raw_ptr<Device> mDevice;
     const VkDeviceSize mMaxSizeForSuballocation;
     MemoryTypeSelector mMemoryTypeSelector;
diff --git a/src/dawn/tests/end2end/AllocatorMemoryInstrumentationTests.cpp b/src/dawn/tests/end2end/AllocatorMemoryInstrumentationTests.cpp
index 792caac..26e4a93 100644
--- a/src/dawn/tests/end2end/AllocatorMemoryInstrumentationTests.cpp
+++ b/src/dawn/tests/end2end/AllocatorMemoryInstrumentationTests.cpp
@@ -42,7 +42,10 @@
 };
 
 // Test the detailed memory usage reported by GetAllocatorMemoryInfo()
-TEST_P(AllocatorMemoryInstrumentationTest, GetAllocatorMemoryInfoVulkan) {
+TEST_P(AllocatorMemoryInstrumentationTest, GetAllocatorMemoryInfo) {
+    native::AllocatorMemoryInfo memInfo = native::GetAllocatorMemoryInfo(device.Get());
+    auto usedMemoryInInitialization = memInfo.totalUsedMemory;
+
     // Create a buffer with size 32.
     constexpr uint64_t kBufferSize = 32;
     constexpr wgpu::BufferDescriptor kBufferDesc = {
@@ -50,11 +53,11 @@
         .size = kBufferSize,
     };
 
-    // Creating the buffer should allocate memory with Vulkan ResourceMemoryAllocator.
+    // Creating the buffer should allocate memory with ResourceMemoryAllocator.
     wgpu::Buffer uniformBuffer = device.CreateBuffer(&kBufferDesc);
     EXPECT_TRUE(uniformBuffer);
 
-    native::AllocatorMemoryInfo memInfo = native::GetAllocatorMemoryInfo(device.Get());
+    memInfo = native::GetAllocatorMemoryInfo(device.Get());
     EXPECT_GT(memInfo.totalAllocatedMemory, 0u);
     EXPECT_GT(memInfo.totalUsedMemory, 0u);
     EXPECT_GE(memInfo.totalAllocatedMemory, memInfo.totalUsedMemory);
@@ -81,12 +84,11 @@
     device.Tick();
 
     memInfo = native::GetAllocatorMemoryInfo(device.Get());
-    // Vulkan used memory should be 0 now.
-    EXPECT_EQ(memInfo.totalUsedMemory, 0u);
+    EXPECT_EQ(memInfo.totalUsedMemory, usedMemoryInInitialization);
     EXPECT_LE(memInfo.totalAllocatedMemory, prevAllocatedMemory);
 }
 
-DAWN_INSTANTIATE_TEST(AllocatorMemoryInstrumentationTest, VulkanBackend());
+DAWN_INSTANTIATE_TEST(AllocatorMemoryInstrumentationTest, D3D12Backend(), VulkanBackend());
 
 }  // anonymous namespace
 }  // namespace dawn