Fix race condition on metal backend with mCompletedSerial.

There is currently a race condition in the metal backend with the
updating of `mCompletedSerial`. It is currently possible for the
`addCompletedHandler` to set to one value and then have the
`CheckAndUpdateCompletedSerial` call immediately set it back to
a lower value. This can then cause hangs as the serial never
correctly updates again.

This was happening on `dawn.node` when running a large number of
CTS test cases all at the same time.

Change-Id: I28fc58ab2b3737ca8039559718e539ce819e88bb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89780
Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn/native/metal/DeviceMTL.mm b/src/dawn/native/metal/DeviceMTL.mm
index 6f4b4d9..32550d5 100644
--- a/src/dawn/native/metal/DeviceMTL.mm
+++ b/src/dawn/native/metal/DeviceMTL.mm
@@ -295,13 +295,18 @@
 
 ResultOrError<ExecutionSerial> Device::CheckAndUpdateCompletedSerials() {
     uint64_t frontendCompletedSerial{GetCompletedCommandSerial()};
-    if (frontendCompletedSerial > mCompletedSerial) {
-        // sometimes we increase the serials, in which case the completed serial in
-        // the device base will surpass the completed serial we have in the metal backend, so we
-        // must update ours when we see that the completed serial from device base has
-        // increased.
-        mCompletedSerial = frontendCompletedSerial;
+    // sometimes we increase the serials, in which case the completed serial in
+    // the device base will surpass the completed serial we have in the metal backend, so we
+    // must update ours when we see that the completed serial from device base has
+    // increased.
+    //
+    // This update has to be atomic otherwise there is a race with the `addCompletedHandler`
+    // call below and this call could set the mCompletedSerial backwards.
+    uint64_t current = mCompletedSerial.load();
+    while (frontendCompletedSerial > current &&
+           !mCompletedSerial.compare_exchange_weak(current, frontendCompletedSerial)) {
     }
+
     return ExecutionSerial(mCompletedSerial.load());
 }