dawn_wire/client: Add RequestTracker helper

This helper ensures correct handling of request maps by:

 - Forcing a request to be erased as soon as it is acquired. This
   prevents some cases of iterator invalidation if we later change the
   container type.
 - Implementing correct closing of all callbacks, including when the
   callbacks themselves add more requests.

Bug: dawn:1092

Change-Id: Ia0ba9f050bbf3f0dee846f537910523bebb3bf1b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/63003
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_wire/client/Buffer.cpp b/src/dawn_wire/client/Buffer.cpp
index 2233c81..f27b99e 100644
--- a/src/dawn_wire/client/Buffer.cpp
+++ b/src/dawn_wire/client/Buffer.cpp
@@ -140,25 +140,24 @@
     }
 
     Buffer::~Buffer() {
         // Callbacks need to be fired in all cases, as they can handle freeing resources
         // so we call them with "DestroyedBeforeCallback" status.
-        for (auto& it : mRequests) {
-            if (it.second.callback) {
-                it.second.callback(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, it.second.userdata);
-            }
-        }
-        mRequests.clear();
-
+        ClearAllCallbacks(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
         FreeMappedData();
     }
 
     void Buffer::CancelCallbacksForDisconnect() {
-        for (auto& it : mRequests) {
-            if (it.second.callback) {
-                it.second.callback(WGPUBufferMapAsyncStatus_DeviceLost, it.second.userdata);
+        ClearAllCallbacks(WGPUBufferMapAsyncStatus_DeviceLost);
+    }
+
+    // Calls the callback (if set) of every in-flight map request with |status| and drops the
+    // requests, including any that those callbacks add reentrantly.
+    void Buffer::ClearAllCallbacks(WGPUBufferMapAsyncStatus status) {
+        mRequests.CloseAll([status](MapRequestData* request) {
+            if (request->callback != nullptr) {
+                request->callback(status, request->userdata);
             }
-        }
-        mRequests.clear();
+        });
     }
 
     void Buffer::MapAsync(WGPUMapModeFlags mode,
@@ -177,10 +172,7 @@
 
         // Create the request structure that will hold information while this mapping is
         // in flight.
-        uint64_t serial = mRequestSerial++;
-        ASSERT(mRequests.find(serial) == mRequests.end());
-
-        Buffer::MapRequestData request = {};
+        MapRequestData request = {};
         request.callback = callback;
         request.userdata = userdata;
         request.offset = offset;
@@ -191,6 +183,8 @@
             request.type = MapRequestType::Write;
         }
 
+        uint64_t serial = mRequests.Add(std::move(request));
+
         // Serialize the command to send to the server.
         BufferMapAsyncCmd cmd;
         cmd.bufferId = this->id;
@@ -200,26 +194,17 @@
         cmd.size = size;
 
         client->SerializeCommand(cmd);
-
-        // Register this request so that we can retrieve it from its serial when the server
-        // sends the callback.
-        mRequests[serial] = std::move(request);
     }
 
     bool Buffer::OnMapAsyncCallback(uint64_t requestSerial,
                                     uint32_t status,
                                     uint64_t readDataUpdateInfoLength,
                                     const uint8_t* readDataUpdateInfo) {
-        auto requestIt = mRequests.find(requestSerial);
-        if (requestIt == mRequests.end()) {
+        MapRequestData request;
+        if (!mRequests.Acquire(requestSerial, &request)) {
             return false;
         }
 
-        auto request = std::move(requestIt->second);
-        // Delete the request before calling the callback otherwise the callback could be fired a
-        // second time. If, for example, buffer.Unmap() is called inside the callback.
-        mRequests.erase(requestIt);
-
         auto FailRequest = [&request]() -> bool {
             if (request.callback != nullptr) {
                 request.callback(WGPUBufferMapAsyncStatus_DeviceLost, request.userdata);
@@ -352,11 +337,11 @@
         mMapSize = 0;
 
         // Tag all mapping requests still in flight as unmapped before callback.
-        for (auto& it : mRequests) {
-            if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) {
-                it.second.clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback;
+        mRequests.ForAll([](MapRequestData* request) {
+            if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
+                request->clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback;
             }
-        }
+        });
 
         BufferUnmapCmd cmd;
         cmd.self = ToAPI(this);
@@ -368,11 +353,11 @@
         FreeMappedData();
 
         // Tag all mapping requests still in flight as destroyed before callback.
-        for (auto& it : mRequests) {
-            if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) {
-                it.second.clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
+        mRequests.ForAll([](MapRequestData* request) {
+            if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
+                request->clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
             }
-        }
+        });
 
         BufferDestroyCmd cmd;
         cmd.self = ToAPI(this);
diff --git a/src/dawn_wire/client/Buffer.h b/src/dawn_wire/client/Buffer.h
index a7d3fab..0a24384 100644
--- a/src/dawn_wire/client/Buffer.h
+++ b/src/dawn_wire/client/Buffer.h
@@ -19,8 +19,7 @@
 
 #include "dawn_wire/WireClient.h"
 #include "dawn_wire/client/ObjectBase.h"
-
-#include <map>
+#include "dawn_wire/client/RequestTracker.h"
 
 namespace dawn_wire { namespace client {
 
@@ -52,6 +51,7 @@
 
       private:
         void CancelCallbacksForDisconnect() override;
+        void ClearAllCallbacks(WGPUBufferMapAsyncStatus status);
 
         bool IsMappedForReading() const;
         bool IsMappedForWriting() const;
@@ -86,8 +86,7 @@
 
             MapRequestType type = MapRequestType::None;
         };
-        std::map<uint64_t, MapRequestData> mRequests;
-        uint64_t mRequestSerial = 0;
+        RequestTracker<MapRequestData> mRequests;
         uint64_t mSize = 0;
 
         // Only one mapped pointer can be active at a time because Unmap clears all the in-flight
diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h
index 3616e37..fc3758a 100644
--- a/src/dawn_wire/client/Client.h
+++ b/src/dawn_wire/client/Client.h
@@ -19,6 +19,7 @@
 #include <dawn_wire/Wire.h>
 
 #include "common/LinkedList.h"
+#include "common/NonCopyable.h"
 #include "dawn_wire/ChunkedCommandSerializer.h"
 #include "dawn_wire/WireClient.h"
 #include "dawn_wire/WireCmd_autogen.h"
diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp
index 95be206..17f98a5 100644
--- a/src/dawn_wire/client/Device.cpp
+++ b/src/dawn_wire/client/Device.cpp
@@ -48,26 +48,23 @@
     }
 
     Device::~Device() {
-        // Fire pending error scopes
-        auto errorScopes = std::move(mErrorScopes);
-        for (const auto& it : errorScopes) {
-            it.second.callback(WGPUErrorType_Unknown, "Device destroyed before callback",
-                               it.second.userdata);
-        }
+        mErrorScopes.CloseAll([](ErrorScopeData* request) {
+            request->callback(WGPUErrorType_Unknown, "Device destroyed before callback",
+                              request->userdata);
+        });
 
-        auto createPipelineAsyncRequests = std::move(mCreatePipelineAsyncRequests);
-        for (const auto& it : createPipelineAsyncRequests) {
-            if (it.second.createComputePipelineAsyncCallback != nullptr) {
-                it.second.createComputePipelineAsyncCallback(
+        mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
+            if (request->createComputePipelineAsyncCallback != nullptr) {
+                request->createComputePipelineAsyncCallback(
                     WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
-                    "Device destroyed before callback", it.second.userdata);
+                    "Device destroyed before callback", request->userdata);
             } else {
-                ASSERT(it.second.createRenderPipelineAsyncCallback != nullptr);
-                it.second.createRenderPipelineAsyncCallback(
+                ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
+                request->createRenderPipelineAsyncCallback(
                     WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
-                    "Device destroyed before callback", it.second.userdata);
+                    "Device destroyed before callback", request->userdata);
             }
-        }
+        });
     }
 
     void Device::HandleError(WGPUErrorType errorType, const char* message) {
@@ -91,25 +88,22 @@
     }
 
     void Device::CancelCallbacksForDisconnect() {
-        for (auto& it : mCreatePipelineAsyncRequests) {
-            ASSERT((it.second.createComputePipelineAsyncCallback != nullptr) ^
-                   (it.second.createRenderPipelineAsyncCallback != nullptr));
-            if (it.second.createRenderPipelineAsyncCallback) {
-                it.second.createRenderPipelineAsyncCallback(
-                    WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
-                    it.second.userdata);
-            } else {
-                it.second.createComputePipelineAsyncCallback(
-                    WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
-                    it.second.userdata);
-            }
-        }
-        mCreatePipelineAsyncRequests.clear();
+        mErrorScopes.CloseAll([](ErrorScopeData* request) {
+            request->callback(WGPUErrorType_DeviceLost, "Device lost", request->userdata);
+        });
 
-        for (auto& it : mErrorScopes) {
-            it.second.callback(WGPUErrorType_DeviceLost, "Device lost", it.second.userdata);
-        }
-        mErrorScopes.clear();
+        mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
+            if (request->createComputePipelineAsyncCallback != nullptr) {
+                request->createComputePipelineAsyncCallback(
+                    WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
+                    request->userdata);
+            } else {
+                ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
+                request->createRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost,
+                                                           nullptr, "Device lost",
+                                                           request->userdata);
+            }
+        });
     }
 
     std::weak_ptr<bool> Device::GetAliveWeakPtr() {
@@ -152,10 +146,7 @@
             return true;
         }
 
-        uint64_t serial = mErrorScopeRequestSerial++;
-        ASSERT(mErrorScopes.find(serial) == mErrorScopes.end());
-
-        mErrorScopes[serial] = {callback, userdata};
+        uint64_t serial = mErrorScopes.Add({callback, userdata});
 
         DevicePopErrorScopeCmd cmd;
         cmd.deviceId = this->id;
@@ -180,14 +171,11 @@
                 return false;
         }
 
-        auto requestIt = mErrorScopes.find(requestSerial);
-        if (requestIt == mErrorScopes.end()) {
+        ErrorScopeData request;
+        if (!mErrorScopes.Acquire(requestSerial, &request)) {
             return false;
         }
 
-        ErrorScopeData request = std::move(requestIt->second);
-
-        mErrorScopes.erase(requestIt);
         request.callback(type, message, request.userdata);
         return true;
     }
@@ -265,9 +253,6 @@
                             "GPU device disconnected", userdata);
         }
 
-        DeviceCreateComputePipelineAsyncCmd cmd;
-        cmd.deviceId = this->id;
-
         // Copy compute to the deprecated computeStage or visa-versa, depending on which one is
         // populated, so that serialization doesn't fail.
         // TODO(dawn:800): Remove once computeStage is removed.
@@ -280,35 +265,32 @@
             localDescriptor.compute.entryPoint = localDescriptor.computeStage.entryPoint;
         }
 
-        cmd.descriptor = &localDescriptor;
-
-        uint64_t serial = mCreatePipelineAsyncRequestSerial++;
-        ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
-        cmd.requestSerial = serial;
-
         auto* allocation = client->ComputePipelineAllocator().New(client);
+
         CreatePipelineAsyncRequest request = {};
         request.createComputePipelineAsyncCallback = callback;
         request.userdata = userdata;
         request.pipelineObjectID = allocation->object->id;
 
-        cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
-        client->SerializeCommand(cmd);
+        uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
 
-        mCreatePipelineAsyncRequests[serial] = std::move(request);
+        DeviceCreateComputePipelineAsyncCmd cmd;
+        cmd.deviceId = this->id;
+        cmd.descriptor = &localDescriptor;
+        cmd.requestSerial = serial;
+        cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
+
+        client->SerializeCommand(cmd);
     }
 
     bool Device::OnCreateComputePipelineAsyncCallback(uint64_t requestSerial,
                                                       WGPUCreatePipelineAsyncStatus status,
                                                       const char* message) {
-        const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial);
-        if (requestIt == mCreatePipelineAsyncRequests.end()) {
+        CreatePipelineAsyncRequest request;
+        if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
             return false;
         }
 
-        CreatePipelineAsyncRequest request = std::move(requestIt->second);
-        mCreatePipelineAsyncRequests.erase(requestIt);
-
         auto pipelineAllocation =
             client->ComputePipelineAllocator().GetObject(request.pipelineObjectID);
 
@@ -333,37 +315,33 @@
             return callback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
                             "GPU device disconnected", userdata);
         }
-        DeviceCreateRenderPipelineAsyncCmd cmd;
-        cmd.deviceId = this->id;
-        cmd.descriptor = descriptor;
-
-        uint64_t serial = mCreatePipelineAsyncRequestSerial++;
-        ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
-        cmd.requestSerial = serial;
 
         auto* allocation = client->RenderPipelineAllocator().New(client);
+
         CreatePipelineAsyncRequest request = {};
         request.createRenderPipelineAsyncCallback = callback;
         request.userdata = userdata;
         request.pipelineObjectID = allocation->object->id;
 
-        cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
-        client->SerializeCommand(cmd);
+        uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
 
-        mCreatePipelineAsyncRequests[serial] = std::move(request);
+        DeviceCreateRenderPipelineAsyncCmd cmd;
+        cmd.deviceId = this->id;
+        cmd.descriptor = descriptor;
+        cmd.requestSerial = serial;
+        cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
+
+        client->SerializeCommand(cmd);
     }
 
     bool Device::OnCreateRenderPipelineAsyncCallback(uint64_t requestSerial,
                                                      WGPUCreatePipelineAsyncStatus status,
                                                      const char* message) {
-        const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial);
-        if (requestIt == mCreatePipelineAsyncRequests.end()) {
+        CreatePipelineAsyncRequest request;
+        if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
             return false;
         }
 
-        CreatePipelineAsyncRequest request = std::move(requestIt->second);
-        mCreatePipelineAsyncRequests.erase(requestIt);
-
         auto pipelineAllocation =
             client->RenderPipelineAllocator().GetObject(request.pipelineObjectID);
 
diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h
index 0bc2ca3..849364f 100644
--- a/src/dawn_wire/client/Device.h
+++ b/src/dawn_wire/client/Device.h
@@ -21,8 +21,8 @@
 #include "dawn_wire/WireCmd_autogen.h"
 #include "dawn_wire/client/ApiObjects_autogen.h"
 #include "dawn_wire/client/ObjectBase.h"
+#include "dawn_wire/client/RequestTracker.h"
 
-#include <map>
 #include <memory>
 
 namespace dawn_wire { namespace client {
@@ -75,8 +75,7 @@
             WGPUErrorCallback callback = nullptr;
             void* userdata = nullptr;
         };
-        std::map<uint64_t, ErrorScopeData> mErrorScopes;
-        uint64_t mErrorScopeRequestSerial = 0;
+        RequestTracker<ErrorScopeData> mErrorScopes;
         uint64_t mErrorScopeStackSize = 0;
 
         struct CreatePipelineAsyncRequest {
@@ -85,8 +84,7 @@
             void* userdata = nullptr;
             ObjectId pipelineObjectID;
         };
-        std::map<uint64_t, CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests;
-        uint64_t mCreatePipelineAsyncRequestSerial = 0;
+        RequestTracker<CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests;
 
         WGPUErrorCallback mErrorCallback = nullptr;
         WGPUDeviceLostCallback mDeviceLostCallback = nullptr;
diff --git a/src/dawn_wire/client/Queue.cpp b/src/dawn_wire/client/Queue.cpp
index 1ac8c77..098ddc5 100644
--- a/src/dawn_wire/client/Queue.cpp
+++ b/src/dawn_wire/client/Queue.cpp
@@ -24,17 +24,11 @@
     }
 
     bool Queue::OnWorkDoneCallback(uint64_t requestSerial, WGPUQueueWorkDoneStatus status) {
-        auto requestIt = mOnWorkDoneRequests.find(requestSerial);
-        if (requestIt == mOnWorkDoneRequests.end()) {
+        OnWorkDoneData request;
+        if (!mOnWorkDoneRequests.Acquire(requestSerial, &request)) {
             return false;
         }
 
-        // Remove the request data so that the callback cannot be called again.
-        // ex.) inside the callback: if the queue is deleted (when there are multiple queues),
-        // all callbacks reject.
-        OnWorkDoneData request = std::move(requestIt->second);
-        mOnWorkDoneRequests.erase(requestIt);
-
         request.callback(status, request.userdata);
         return true;
     }
@@ -47,16 +41,13 @@
             return;
         }
 
-        uint32_t serial = mOnWorkDoneSerial++;
-        ASSERT(mOnWorkDoneRequests.find(serial) == mOnWorkDoneRequests.end());
+        uint64_t serial = mOnWorkDoneRequests.Add({callback, userdata});
 
         QueueOnSubmittedWorkDoneCmd cmd;
         cmd.queueId = this->id;
         cmd.signalValue = signalValue;
         cmd.requestSerial = serial;
 
-        mOnWorkDoneRequests[serial] = {callback, userdata};
-
         client->SerializeCommand(cmd);
     }
 
@@ -97,12 +88,11 @@
     }
 
     void Queue::ClearAllCallbacks(WGPUQueueWorkDoneStatus status) {
-        for (auto& it : mOnWorkDoneRequests) {
-            if (it.second.callback) {
-                it.second.callback(status, it.second.userdata);
+        mOnWorkDoneRequests.CloseAll([status](OnWorkDoneData* request) {
+            if (request->callback != nullptr) {
+                request->callback(status, request->userdata);
             }
-        }
-        mOnWorkDoneRequests.clear();
+        });
     }
 
 }}  // namespace dawn_wire::client
diff --git a/src/dawn_wire/client/Queue.h b/src/dawn_wire/client/Queue.h
index d8e93a3..901acac 100644
--- a/src/dawn_wire/client/Queue.h
+++ b/src/dawn_wire/client/Queue.h
@@ -19,8 +19,7 @@
 
 #include "dawn_wire/WireClient.h"
 #include "dawn_wire/client/ObjectBase.h"
-
-#include <map>
+#include "dawn_wire/client/RequestTracker.h"
 
 namespace dawn_wire { namespace client {
 
@@ -44,15 +43,13 @@
 
       private:
         void CancelCallbacksForDisconnect() override;
-
         void ClearAllCallbacks(WGPUQueueWorkDoneStatus status);
 
         struct OnWorkDoneData {
             WGPUQueueWorkDoneCallback callback = nullptr;
             void* userdata = nullptr;
         };
-        uint64_t mOnWorkDoneSerial = 0;
-        std::map<uint64_t, OnWorkDoneData> mOnWorkDoneRequests;
+        RequestTracker<OnWorkDoneData> mOnWorkDoneRequests;
     };
 
 }}  // namespace dawn_wire::client
diff --git a/src/dawn_wire/client/RequestTracker.h b/src/dawn_wire/client/RequestTracker.h
new file mode 100644
index 0000000..7ce2d00
--- /dev/null
+++ b/src/dawn_wire/client/RequestTracker.h
@@ -0,0 +1,98 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNWIRE_CLIENT_REQUESTTRACKER_H_
+#define DAWNWIRE_CLIENT_REQUESTTRACKER_H_
+
+#include "common/Assert.h"
+#include "common/NonCopyable.h"
+
+#include <cstdint>
+#include <map>
+#include <utility>
+
+namespace dawn_wire { namespace client {
+
+    // Tracks in-flight asynchronous requests of type |Request|, each identified by a serial
+    // allocated by Add(). Callers put that serial in the wire command they serialize and look
+    // the request back up with Acquire() in the matching On*Callback handler.
+    //
+    // Acquire() erases the request before returning it, so a user callback run by the caller
+    // cannot observe or re-fire it (e.g. through a reentrant CloseAll()). Every request must be
+    // closed through Acquire() or CloseAll() before the tracker is destroyed.
+    template <typename Request>
+    class RequestTracker : NonCopyable {
+      public:
+        ~RequestTracker() {
+            ASSERT(mRequests.empty());
+        }
+
+        // Takes ownership of |request| and returns the serial identifying it. Serials start
+        // at 1 and increase monotonically.
+        uint64_t Add(Request&& request) {
+            mSerial++;
+            ASSERT(mRequests.find(mSerial) == mRequests.end());
+            mRequests.emplace(mSerial, std::move(request));
+            return mSerial;
+        }
+
+        // If |serial| is tracked, moves its request into |*request|, stops tracking it and
+        // returns true. Otherwise returns false and leaves |*request| unchanged.
+        bool Acquire(uint64_t serial, Request* request) {
+            auto it = mRequests.find(serial);
+            if (it == mRequests.end()) {
+                return false;
+            }
+            *request = std::move(it->second);
+            mRequests.erase(it);
+            return true;
+        }
+
+        // Calls |closeFunc| with a pointer to every tracked request and stops tracking them.
+        template <typename CloseFunc>
+        void CloseAll(CloseFunc&& closeFunc) {
+            // Call closeFunc on all requests while handling reentrancy where the callback of some
+            // requests may add some additional requests. We guarantee all callbacks for requests
+            // are called exactly once, so keep closing new requests if the first batch added more.
+            // It is fine to loop infinitely here if that's what the application makes us do.
+            while (!mRequests.empty()) {
+                // Take the current batch and leave mRequests empty (std::exchange guarantees
+                // this; a moved-from map is only "valid but unspecified"). Requests added
+                // reentrantly by closeFunc then go into the fresh map: they don't invalidate the
+                // iterators below and are closed by the next iteration of the while loop.
+                auto allRequests = std::exchange(mRequests, {});
+                for (auto& it : allRequests) {
+                    closeFunc(&it.second);
+                }
+            }
+        }
+
+        // Calls |f| with a pointer to every tracked request (in serial order) without removing
+        // them. |f| must not erase requests from this tracker while it iterates.
+        template <typename F>
+        void ForAll(F&& f) {
+            for (auto& it : mRequests) {
+                f(&it.second);
+            }
+        }
+
+      private:
+        // Last serial returned by Add(); 0 until the first request is added.
+        uint64_t mSerial = 0;
+        std::map<uint64_t, Request> mRequests;
+    };
+
+}}  // namespace dawn_wire::client
+
+#endif  // DAWNWIRE_CLIENT_REQUESTTRACKER_H_
diff --git a/src/dawn_wire/client/ShaderModule.cpp b/src/dawn_wire/client/ShaderModule.cpp
index fa7945a..c28b978 100644
--- a/src/dawn_wire/client/ShaderModule.cpp
+++ b/src/dawn_wire/client/ShaderModule.cpp
@@ -19,15 +19,9 @@
 namespace dawn_wire { namespace client {
 
     ShaderModule::~ShaderModule() {
         // Callbacks need to be fired in all cases, as they can handle freeing resources. So we call
         // them with "Unknown" status.
-        for (auto& it : mCompilationInfoRequests) {
-            if (it.second.callback) {
-                it.second.callback(WGPUCompilationInfoRequestStatus_Unknown, nullptr,
-                                   it.second.userdata);
-            }
-        }
-        mCompilationInfoRequests.clear();
+        ClearAllCallbacks(WGPUCompilationInfoRequestStatus_Unknown);
     }
 
     void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata) {
@@ -36,41 +28,37 @@
             return;
         }
 
-        uint64_t serial = mCompilationInfoRequestSerial++;
+        uint64_t serial = mCompilationInfoRequests.Add({callback, userdata});
+
         ShaderModuleGetCompilationInfoCmd cmd;
         cmd.shaderModuleId = this->id;
         cmd.requestSerial = serial;
 
-        mCompilationInfoRequests[serial] = {callback, userdata};
-
         client->SerializeCommand(cmd);
     }
 
     bool ShaderModule::GetCompilationInfoCallback(uint64_t requestSerial,
                                                   WGPUCompilationInfoRequestStatus status,
                                                   const WGPUCompilationInfo* info) {
-        auto requestIt = mCompilationInfoRequests.find(requestSerial);
-        if (requestIt == mCompilationInfoRequests.end()) {
+        CompilationInfoRequest request;
+        if (!mCompilationInfoRequests.Acquire(requestSerial, &request)) {
             return false;
         }
 
-        // Remove the request data so that the callback cannot be called again.
-        // ex.) inside the callback: if the shader module is deleted, all callbacks reject.
-        CompilationInfoRequest request = std::move(requestIt->second);
-        mCompilationInfoRequests.erase(requestIt);
-
         request.callback(status, info, request.userdata);
         return true;
     }
 
     void ShaderModule::CancelCallbacksForDisconnect() {
-        for (auto& it : mCompilationInfoRequests) {
-            if (it.second.callback) {
-                it.second.callback(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr,
-                                   it.second.userdata);
+        ClearAllCallbacks(WGPUCompilationInfoRequestStatus_DeviceLost);
+    }
+
+    void ShaderModule::ClearAllCallbacks(WGPUCompilationInfoRequestStatus status) {
+        mCompilationInfoRequests.CloseAll([status](CompilationInfoRequest* request) {
+            if (request->callback != nullptr) {
+                request->callback(status, nullptr, request->userdata);
             }
-        }
-        mCompilationInfoRequests.clear();
+        });
     }
 
 }}  // namespace dawn_wire::client
diff --git a/src/dawn_wire/client/ShaderModule.h b/src/dawn_wire/client/ShaderModule.h
index d7ac55d..f12a4d0 100644
--- a/src/dawn_wire/client/ShaderModule.h
+++ b/src/dawn_wire/client/ShaderModule.h
@@ -17,8 +17,8 @@
 
 #include <dawn/webgpu.h>
 
-#include "common/SerialMap.h"
 #include "dawn_wire/client/ObjectBase.h"
+#include "dawn_wire/client/RequestTracker.h"
 
 namespace dawn_wire { namespace client {
 
@@ -32,15 +32,15 @@
                                         WGPUCompilationInfoRequestStatus status,
                                         const WGPUCompilationInfo* info);
 
-        void CancelCallbacksForDisconnect() override;
-
       private:
+        void CancelCallbacksForDisconnect() override;
+        void ClearAllCallbacks(WGPUCompilationInfoRequestStatus status);
+
         struct CompilationInfoRequest {
             WGPUCompilationInfoCallback callback = nullptr;
             void* userdata = nullptr;
         };
-        uint64_t mCompilationInfoRequestSerial = 0;
-        std::map<uint64_t, CompilationInfoRequest> mCompilationInfoRequests;
+        RequestTracker<CompilationInfoRequest> mCompilationInfoRequests;
     };
 
 }}  // namespace dawn_wire::client