Add wire serialized command buffer padding.

Pads serialized wire command buffers to 8 bytes so that we don't have
misaligned write/reads which can cause SIGILL depending on platform and
compilation mode, i.e. -c dbg in google3 builds.

- Adds helpers for aligning sizeof calls.
- Adds constant for wire padding (8u).
- Modifies BufferConsumer to allocate according to padding. This
  guarantees that when we [de]serialize stuff, the padding should be
  equal on both sides.
- Modifies extra byte serialization code (adding CommandExtension
  struct). This makes it clearer that each extension needs to be
  padded independently. Otherwise, before in wire/client/Buffer.cpp,
  since the read/write handle sizes were being passed as a sum, but
  read out separately from the BufferConsumer, we corrupt our pointers.
- Adds some simple unit tests.

Bug: dawn:1334
Change-Id: Id80e7c01a34b9f01c3f02b3e6c04c3bb3ad0eff9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110501
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/generator/templates/dawn/wire/WireCmd.cpp b/generator/templates/dawn/wire/WireCmd.cpp
index 26c0291..10a7d36 100644
--- a/generator/templates/dawn/wire/WireCmd.cpp
+++ b/generator/templates/dawn/wire/WireCmd.cpp
@@ -165,7 +165,7 @@
                 if (has_{{memberName}})
             {% endif %}
             {
-            result += std::strlen(record.{{memberName}});
+            result += Align(std::strlen(record.{{memberName}}), kWireBufferAlignment);
             }
         {% endfor %}
 
@@ -178,7 +178,9 @@
                 {% if member.annotation != "value" %}
                     {{ assert(member.annotation != "const*const*") }}
                     auto memberLength = {{member_length(member, "record.")}};
-                    result += memberLength * {{member_transfer_sizeof(member)}};
+                    auto size = WireAlignSizeofN<{{member_transfer_type(member)}}>(memberLength);
+                    ASSERT(size);
+                    result += *size;
                     //* Structures might contain more pointers so we need to add their extra size as well.
                     {% if member.type.category == "structure" %}
                         for (decltype(memberLength) i = 0; i < memberLength; ++i) {
@@ -431,7 +433,7 @@
     {% set Cmd = Name + "Cmd" %}
 
     size_t {{Cmd}}::GetRequiredSize() const {
-        size_t size = sizeof({{Name}}Transfer) + {{Name}}GetExtraRequiredSize(*this);
+        size_t size = WireAlignSizeof<{{Name}}Transfer>() + {{Name}}GetExtraRequiredSize(*this);
         return size;
     }
 
@@ -509,7 +511,7 @@
                     ) %}
                         case {{as_cEnum(types["s type"].name, sType.name)}}: {
                             const auto& typedStruct = *reinterpret_cast<{{as_cType(sType.name)}} const *>(chainedStruct);
-                            result += sizeof({{as_cType(sType.name)}}Transfer);
+                            result += WireAlignSizeof<{{as_cType(sType.name)}}Transfer>();
                             result += {{as_cType(sType.name)}}GetExtraRequiredSize(typedStruct);
                             chainedStruct = typedStruct.chain.next;
                             break;
@@ -519,7 +521,7 @@
                     case WGPUSType_Invalid:
                     default:
                         // Invalid enum. Reserve space just for the transfer header (sType and hasNext).
-                        result += sizeof(WGPUChainedStructTransfer);
+                        result += WireAlignSizeof<WGPUChainedStructTransfer>();
                         chainedStruct = chainedStruct->next;
                         break;
                 }
@@ -600,7 +602,7 @@
                             WIRE_TRY(deserializeBuffer->Read(&transfer));
 
                             {{CType}}* outStruct;
-                            WIRE_TRY(GetSpace(allocator, sizeof({{CType}}), &outStruct));
+                            WIRE_TRY(GetSpace(allocator, 1u, &outStruct));
                             outStruct->chain.sType = sType;
                             outStruct->chain.next = nullptr;
 
@@ -629,7 +631,7 @@
                         WIRE_TRY(deserializeBuffer->Read(&transfer));
 
                         {{ChainedStruct}}* outStruct;
-                        WIRE_TRY(GetSpace(allocator, sizeof({{ChainedStruct}}), &outStruct));
+                        WIRE_TRY(GetSpace(allocator, 1u, &outStruct));
                         outStruct->sType = WGPUSType_Invalid;
                         outStruct->next = nullptr;
 
@@ -654,13 +656,23 @@
         // Always writes to |out| on success.
         template <typename T, typename N>
         WireResult GetSpace(DeserializeAllocator* allocator, N count, T** out) {
-            constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T);
-            if (count > kMaxCountWithoutOverflows) {
+            // Because we use this function extensively when `count` == 1, we can optimize the
+            // size computations a bit more for those cases via constexpr version of the
+            // alignment computation.
+            constexpr size_t kSizeofT = WireAlignSizeof<T>();
+            size_t size = 0;
+            if (count == 1) {
+              size = kSizeofT;
+            } else {
+              auto sizeN = WireAlignSizeofN<T>(count);
+              // A size of 0 indicates an overflow, so return an error.
+              if (!sizeN) {
                 return WireResult::FatalError;
+              }
+              size = *sizeN;
             }
 
-            size_t totalSize = sizeof(T) * count;
-            *out = static_cast<T*>(allocator->GetSpace(totalSize));
+            *out = static_cast<T*>(allocator->GetSpace(size));
             if (*out == nullptr) {
                 return WireResult::FatalError;
             }
diff --git a/src/dawn/common/Constants.h b/src/dawn/common/Constants.h
index bf1f132..027b6c0 100644
--- a/src/dawn/common/Constants.h
+++ b/src/dawn/common/Constants.h
@@ -15,6 +15,7 @@
 #ifndef SRC_DAWN_COMMON_CONSTANTS_H_
 #define SRC_DAWN_COMMON_CONSTANTS_H_
 
+#include <cstddef>
 #include <cstdint>
 
 static constexpr uint32_t kMaxBindGroups = 4u;
@@ -65,4 +66,7 @@
 static constexpr uint8_t kSamplersPerExternalTexture = 1u;
 static constexpr uint8_t kUniformsPerExternalTexture = 1u;
 
+// Wire buffer alignments.
+static constexpr size_t kWireBufferAlignment = 8u;
+
 #endif  // SRC_DAWN_COMMON_CONSTANTS_H_
diff --git a/src/dawn/common/Math.h b/src/dawn/common/Math.h
index 9984c4b..c8b518f 100644
--- a/src/dawn/common/Math.h
+++ b/src/dawn/common/Math.h
@@ -20,6 +20,7 @@
 #include <cstring>
 
 #include <limits>
+#include <optional>
 #include <type_traits>
 
 #include "dawn/common/Assert.h"
@@ -61,6 +62,26 @@
     return (value + (alignmentT - 1)) & ~(alignmentT - 1);
 }
 
+template <typename T, size_t Alignment>
+constexpr size_t AlignSizeof() {
+    static_assert(Alignment != 0 && (Alignment & (Alignment - 1)) == 0,
+                  "Alignment must be a valid power of 2.");
+    static_assert(sizeof(T) <= std::numeric_limits<size_t>::max() - (Alignment - 1));
+    return (sizeof(T) + (Alignment - 1)) & ~(Alignment - 1);
+}
+
+// Returns an aligned size for an n-sized array of T elements. If the size would overflow, returns
+// nullopt instead.
+template <typename T, size_t Alignment>
+std::optional<size_t> AlignSizeofN(uint64_t n) {
+    constexpr uint64_t kMaxCountWithoutOverflows =
+        (std::numeric_limits<size_t>::max() - Alignment + 1) / sizeof(T);
+    if (n > kMaxCountWithoutOverflows) {
+        return std::nullopt;
+    }
+    return Align(sizeof(T) * n, Alignment);
+}
+
 template <typename T>
 DAWN_FORCE_INLINE T* AlignPtr(T* ptr, size_t alignment) {
     ASSERT(IsPowerOfTwo(alignment));
diff --git a/src/dawn/tests/unittests/MathTests.cpp b/src/dawn/tests/unittests/MathTests.cpp
index d88e858..a38fb67 100644
--- a/src/dawn/tests/unittests/MathTests.cpp
+++ b/src/dawn/tests/unittests/MathTests.cpp
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include <cmath>
+#include <limits>
 #include <vector>
 
 #include "dawn/EnumClassBitmasks.h"
@@ -180,6 +181,41 @@
     ASSERT_EQ(Align(static_cast<uint64_t>(0xFFFFFFFFFFFFFFFF), 1), 0xFFFFFFFFFFFFFFFFull);
 }
 
+TEST(Math, AlignSizeof) {
+    // Basic types should align to self if alignment is a divisor.
+    ASSERT_EQ((AlignSizeof<uint8_t, 1>()), 1u);
+
+    ASSERT_EQ((AlignSizeof<uint16_t, 1>()), 2u);
+    ASSERT_EQ((AlignSizeof<uint16_t, 2>()), 2u);
+
+    ASSERT_EQ((AlignSizeof<uint32_t, 1>()), 4u);
+    ASSERT_EQ((AlignSizeof<uint32_t, 2>()), 4u);
+    ASSERT_EQ((AlignSizeof<uint32_t, 4>()), 4u);
+
+    ASSERT_EQ((AlignSizeof<uint64_t, 1>()), 8u);
+    ASSERT_EQ((AlignSizeof<uint64_t, 2>()), 8u);
+    ASSERT_EQ((AlignSizeof<uint64_t, 4>()), 8u);
+    ASSERT_EQ((AlignSizeof<uint64_t, 8>()), 8u);
+
+    // Everything in range (align, 2*align] aligns to 2*align.
+    ASSERT_EQ((AlignSizeof<char[5], 4>()), 8u);
+    ASSERT_EQ((AlignSizeof<char[6], 4>()), 8u);
+    ASSERT_EQ((AlignSizeof<char[7], 4>()), 8u);
+    ASSERT_EQ((AlignSizeof<char[8], 4>()), 8u);
+}
+
+TEST(Math, AlignSizeofN) {
+    // Everything in range (align, 2*align] aligns to 2*align.
+    ASSERT_EQ(*(AlignSizeofN<char, 4>(5)), 8u);
+    ASSERT_EQ(*(AlignSizeofN<char, 4>(6)), 8u);
+    ASSERT_EQ(*(AlignSizeofN<char, 4>(7)), 8u);
+    ASSERT_EQ(*(AlignSizeofN<char, 4>(8)), 8u);
+
+    // Extremes should return nullopt.
+    ASSERT_EQ((AlignSizeofN<char, 4>(std::numeric_limits<size_t>::max())), std::nullopt);
+    ASSERT_EQ((AlignSizeofN<char, 4>(std::numeric_limits<uint64_t>::max())), std::nullopt);
+}
+
 // Tests for IsPtrAligned
 TEST(Math, IsPtrAligned) {
     constexpr size_t kTestAlignment = 8;
diff --git a/src/dawn/wire/BufferConsumer.h b/src/dawn/wire/BufferConsumer.h
index 1ae8451..12c4036 100644
--- a/src/dawn/wire/BufferConsumer.h
+++ b/src/dawn/wire/BufferConsumer.h
@@ -17,10 +17,22 @@
 
 #include <cstddef>
 
+#include "dawn/common/Constants.h"
+#include "dawn/common/Math.h"
 #include "dawn/wire/WireResult.h"
 
 namespace dawn::wire {
 
+// Wire specific alignment helpers.
+template <typename T>
+constexpr size_t WireAlignSizeof() {
+    return AlignSizeof<T, kWireBufferAlignment>();
+}
+template <typename T>
+std::optional<size_t> WireAlignSizeofN(size_t n) {
+    return AlignSizeofN<T, kWireBufferAlignment>(n);
+}
+
 // BufferConsumer is a utility class that allows reading bytes from a buffer
 // while simultaneously decrementing the amount of remaining space by exactly
 // the amount read. It helps prevent bugs where incrementing a pointer and
diff --git a/src/dawn/wire/BufferConsumer_impl.h b/src/dawn/wire/BufferConsumer_impl.h
index 6b5d0a1..52b6720 100644
--- a/src/dawn/wire/BufferConsumer_impl.h
+++ b/src/dawn/wire/BufferConsumer_impl.h
@@ -15,11 +15,11 @@
 #ifndef SRC_DAWN_WIRE_BUFFERCONSUMER_IMPL_H_
 #define SRC_DAWN_WIRE_BUFFERCONSUMER_IMPL_H_
 
-#include "dawn/wire/BufferConsumer.h"
-
 #include <limits>
 #include <type_traits>
 
+#include "dawn/wire/BufferConsumer.h"
+
 namespace dawn::wire {
 
 template <typename BufferT>
@@ -36,13 +36,14 @@
 template <typename BufferT>
 template <typename T>
 WireResult BufferConsumer<BufferT>::Next(T** data) {
-    if (sizeof(T) > mSize) {
+    constexpr size_t kSize = WireAlignSizeof<T>();
+    if (kSize > mSize) {
         return WireResult::FatalError;
     }
 
     *data = reinterpret_cast<T*>(mBuffer);
-    mBuffer += sizeof(T);
-    mSize -= sizeof(T);
+    mBuffer += kSize;
+    mSize -= kSize;
     return WireResult::Success;
 }
 
@@ -51,20 +52,15 @@
 WireResult BufferConsumer<BufferT>::NextN(N count, T** data) {
     static_assert(std::is_unsigned<N>::value, "|count| argument of NextN must be unsigned.");
 
-    constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T);
-    if (count > kMaxCountWithoutOverflows) {
-        return WireResult::FatalError;
-    }
-
-    // Cannot overflow because |count| is not greater than |kMaxCountWithoutOverflows|.
-    size_t totalSize = sizeof(T) * count;
-    if (totalSize > mSize) {
+    // If size is zero then it indicates an overflow.
+    auto size = WireAlignSizeofN<T>(count);
+    if (size && *size > mSize) {
         return WireResult::FatalError;
     }
 
     *data = reinterpret_cast<T*>(mBuffer);
-    mBuffer += totalSize;
-    mSize -= totalSize;
+    mBuffer += *size;
+    mSize -= *size;
     return WireResult::Success;
 }
 
diff --git a/src/dawn/wire/ChunkedCommandSerializer.h b/src/dawn/wire/ChunkedCommandSerializer.h
index 7ac72e5..8ad64ff 100644
--- a/src/dawn/wire/ChunkedCommandSerializer.h
+++ b/src/dawn/wire/ChunkedCommandSerializer.h
@@ -17,73 +17,97 @@
 
 #include <algorithm>
 #include <cstring>
+#include <functional>
 #include <memory>
 #include <utility>
 
 #include "dawn/common/Alloc.h"
 #include "dawn/common/Compiler.h"
+#include "dawn/common/Constants.h"
+#include "dawn/common/Math.h"
 #include "dawn/wire/Wire.h"
 #include "dawn/wire/WireCmd_autogen.h"
 
 namespace dawn::wire {
 
+// Simple command extension struct used when a command needs to serialize additional information
+// that is not baked directly into the command already.
+struct CommandExtension {
+    size_t size;
+    std::function<void(char*)> serialize;
+};
+
+namespace detail {
+
+inline WireResult SerializeCommandExtension(SerializeBuffer* serializeBuffer) {
+    return WireResult::Success;
+}
+
+template <typename Extension, typename... Extensions>
+WireResult SerializeCommandExtension(SerializeBuffer* serializeBuffer,
+                                     Extension&& e,
+                                     Extensions&&... es) {
+    char* buffer;
+    WIRE_TRY(serializeBuffer->NextN(e.size, &buffer));
+    e.serialize(buffer);
+
+    WIRE_TRY(SerializeCommandExtension(serializeBuffer, std::forward<Extensions>(es)...));
+    return WireResult::Success;
+}
+
+}  // namespace detail
+
 class ChunkedCommandSerializer {
   public:
     explicit ChunkedCommandSerializer(CommandSerializer* serializer);
 
     template <typename Cmd>
     void SerializeCommand(const Cmd& cmd) {
-        SerializeCommand(cmd, 0, [](SerializeBuffer*) { return WireResult::Success; });
+        SerializeCommandImpl(
+            cmd, [](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) {
+                return cmd.Serialize(requiredSize, serializeBuffer);
+            });
     }
 
-    template <typename Cmd, typename ExtraSizeSerializeFn>
-    void SerializeCommand(const Cmd& cmd,
-                          size_t extraSize,
-                          ExtraSizeSerializeFn&& SerializeExtraSize) {
+    template <typename Cmd, typename... Extensions>
+    void SerializeCommand(const Cmd& cmd, CommandExtension&& e, Extensions&&... es) {
         SerializeCommandImpl(
             cmd,
             [](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) {
                 return cmd.Serialize(requiredSize, serializeBuffer);
             },
-            extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize));
+            std::forward<CommandExtension>(e), std::forward<Extensions>(es)...);
     }
 
-    template <typename Cmd>
-    void SerializeCommand(const Cmd& cmd, const ObjectIdProvider& objectIdProvider) {
-        SerializeCommand(cmd, objectIdProvider, 0,
-                         [](SerializeBuffer*) { return WireResult::Success; });
-    }
-
-    template <typename Cmd, typename ExtraSizeSerializeFn>
+    template <typename Cmd, typename... Extensions>
     void SerializeCommand(const Cmd& cmd,
                           const ObjectIdProvider& objectIdProvider,
-                          size_t extraSize,
-                          ExtraSizeSerializeFn&& SerializeExtraSize) {
+                          Extensions&&... extensions) {
         SerializeCommandImpl(
             cmd,
             [&objectIdProvider](const Cmd& cmd, size_t requiredSize,
                                 SerializeBuffer* serializeBuffer) {
                 return cmd.Serialize(requiredSize, serializeBuffer, objectIdProvider);
             },
-            extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize));
+            std::forward<Extensions>(extensions)...);
     }
 
   private:
-    template <typename Cmd, typename SerializeCmdFn, typename ExtraSizeSerializeFn>
+    template <typename Cmd, typename SerializeCmdFn, typename... Extensions>
     void SerializeCommandImpl(const Cmd& cmd,
                               SerializeCmdFn&& SerializeCmd,
-                              size_t extraSize,
-                              ExtraSizeSerializeFn&& SerializeExtraSize) {
+                              Extensions&&... extensions) {
         size_t commandSize = cmd.GetRequiredSize();
-        size_t requiredSize = commandSize + extraSize;
+        size_t requiredSize = (Align(extensions.size, kWireBufferAlignment) + ... + commandSize);
 
         if (requiredSize <= mMaxAllocationSize) {
             char* allocatedBuffer = static_cast<char*>(mSerializer->GetCmdSpace(requiredSize));
             if (allocatedBuffer != nullptr) {
                 SerializeBuffer serializeBuffer(allocatedBuffer, requiredSize);
-                WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer);
-                WireResult r2 = SerializeExtraSize(&serializeBuffer);
-                if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) {
+                WireResult rCmd = SerializeCmd(cmd, requiredSize, &serializeBuffer);
+                WireResult rExts =
+                    detail::SerializeCommandExtension(&serializeBuffer, extensions...);
+                if (DAWN_UNLIKELY(rCmd != WireResult::Success || rExts != WireResult::Success)) {
                     mSerializer->OnSerializeError();
                 }
             }
@@ -95,9 +119,9 @@
             return;
         }
         SerializeBuffer serializeBuffer(cmdSpace.get(), requiredSize);
-        WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer);
-        WireResult r2 = SerializeExtraSize(&serializeBuffer);
-        if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) {
+        WireResult rCmd = SerializeCmd(cmd, requiredSize, &serializeBuffer);
+        WireResult rExts = detail::SerializeCommandExtension(&serializeBuffer, extensions...);
+        if (DAWN_UNLIKELY(rCmd != WireResult::Success || rExts != WireResult::Success)) {
             mSerializer->OnSerializeError();
             return;
         }
diff --git a/src/dawn/wire/client/Buffer.cpp b/src/dawn/wire/client/Buffer.cpp
index 32da663..4315452 100644
--- a/src/dawn/wire/client/Buffer.cpp
+++ b/src/dawn/wire/client/Buffer.cpp
@@ -47,6 +47,8 @@
     cmd.writeHandleCreateInfoLength = 0;
     cmd.writeHandleCreateInfo = nullptr;
 
+    size_t readHandleCreateInfoLength = 0;
+    size_t writeHandleCreateInfoLength = 0;
     if (mappable) {
         if ((descriptor->usage & WGPUBufferUsage_MapRead) != 0) {
             // Create the read handle on buffer creation.
@@ -56,7 +58,8 @@
                 device->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping");
                 return CreateError(device, descriptor);
             }
-            cmd.readHandleCreateInfoLength = readHandle->SerializeCreateSize();
+            readHandleCreateInfoLength = readHandle->SerializeCreateSize();
+            cmd.readHandleCreateInfoLength = readHandleCreateInfoLength;
         }
 
         if ((descriptor->usage & WGPUBufferUsage_MapWrite) != 0 || descriptor->mappedAtCreation) {
@@ -67,7 +70,8 @@
                 device->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping");
                 return CreateError(device, descriptor);
             }
-            cmd.writeHandleCreateInfoLength = writeHandle->SerializeCreateSize();
+            writeHandleCreateInfoLength = writeHandle->SerializeCreateSize();
+            cmd.writeHandleCreateInfoLength = writeHandleCreateInfoLength;
         }
     }
 
@@ -95,27 +99,28 @@
 
     cmd.result = buffer->GetWireHandle();
 
+    // clang-format off
+    // Turning off clang format here because for some reason it does not format the
+    // CommandExtensions consistently, making it harder to read.
     wireClient->SerializeCommand(
-        cmd, cmd.readHandleCreateInfoLength + cmd.writeHandleCreateInfoLength,
-        [&](SerializeBuffer* serializeBuffer) {
-            if (readHandle != nullptr) {
-                char* readHandleBuffer;
-                WIRE_TRY(serializeBuffer->NextN(cmd.readHandleCreateInfoLength, &readHandleBuffer));
-                // Serialize the ReadHandle into the space after the command.
-                readHandle->SerializeCreate(readHandleBuffer);
-                buffer->mReadHandle = std::move(readHandle);
-            }
-            if (writeHandle != nullptr) {
-                char* writeHandleBuffer;
-                WIRE_TRY(
-                    serializeBuffer->NextN(cmd.writeHandleCreateInfoLength, &writeHandleBuffer));
-                // Serialize the WriteHandle into the space after the command.
-                writeHandle->SerializeCreate(writeHandleBuffer);
-                buffer->mWriteHandle = std::move(writeHandle);
-            }
-
-            return WireResult::Success;
-        });
+        cmd,
+        CommandExtension{readHandleCreateInfoLength,
+                         [&](char* readHandleBuffer) {
+                             if (readHandle != nullptr) {
+                                 // Serialize the ReadHandle into the space after the command.
+                                 readHandle->SerializeCreate(readHandleBuffer);
+                                 buffer->mReadHandle = std::move(readHandle);
+                             }
+                         }},
+        CommandExtension{writeHandleCreateInfoLength,
+                         [&](char* writeHandleBuffer) {
+                             if (writeHandle != nullptr) {
+                                 // Serialize the WriteHandle into the space after the command.
+                                 writeHandle->SerializeCreate(writeHandleBuffer);
+                                 buffer->mWriteHandle = std::move(writeHandle);
+                             }
+                         }});
+    // clang-format on
     return ToAPI(buffer);
 }
 
@@ -310,16 +315,12 @@
         cmd.size = mMapSize;
 
         client->SerializeCommand(
-            cmd, writeDataUpdateInfoLength, [&](SerializeBuffer* serializeBuffer) {
-                char* writeHandleBuffer;
-                WIRE_TRY(serializeBuffer->NextN(writeDataUpdateInfoLength, &writeHandleBuffer));
-
-                // Serialize flush metadata into the space after the command.
-                // This closes the handle for writing.
-                mWriteHandle->SerializeDataUpdate(writeHandleBuffer, cmd.offset, cmd.size);
-
-                return WireResult::Success;
-            });
+            cmd, CommandExtension{writeDataUpdateInfoLength, [&](char* writeHandleBuffer) {
+                                      // Serialize flush metadata into the space after the command.
+                                      // This closes the handle for writing.
+                                      mWriteHandle->SerializeDataUpdate(writeHandleBuffer,
+                                                                        cmd.offset, cmd.size);
+                                  }});
 
         // If mDestructWriteHandleOnUnmap is true, that means the write handle is merely
         // for mappedAtCreation usage. It is destroyed on unmap after flush to server
diff --git a/src/dawn/wire/client/Client.h b/src/dawn/wire/client/Client.h
index 8648522..f16af64 100644
--- a/src/dawn/wire/client/Client.h
+++ b/src/dawn/wire/client/Client.h
@@ -85,11 +85,9 @@
         mSerializer.SerializeCommand(cmd, *this);
     }
 
-    template <typename Cmd, typename ExtraSizeSerializeFn>
-    void SerializeCommand(const Cmd& cmd,
-                          size_t extraSize,
-                          ExtraSizeSerializeFn&& SerializeExtraSize) {
-        mSerializer.SerializeCommand(cmd, *this, extraSize, SerializeExtraSize);
+    template <typename Cmd, typename... Extensions>
+    void SerializeCommand(const Cmd& cmd, Extensions&&... es) {
+        mSerializer.SerializeCommand(cmd, *this, std::forward<Extensions>(es)...);
     }
 
     void Disconnect();
diff --git a/src/dawn/wire/server/Server.h b/src/dawn/wire/server/Server.h
index 2812756..2dfaba7 100644
--- a/src/dawn/wire/server/Server.h
+++ b/src/dawn/wire/server/Server.h
@@ -184,11 +184,9 @@
         mSerializer.SerializeCommand(cmd);
     }
 
-    template <typename Cmd, typename ExtraSizeSerializeFn>
-    void SerializeCommand(const Cmd& cmd,
-                          size_t extraSize,
-                          ExtraSizeSerializeFn&& SerializeExtraSize) {
-        mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize);
+    template <typename Cmd, typename... Extensions>
+    void SerializeCommand(const Cmd& cmd, Extensions&&... es) {
+        mSerializer.SerializeCommand(cmd, std::forward<Extensions>(es)...);
     }
 
     void SetForwardingDeviceCallbacks(ObjectData<WGPUDevice>* deviceObject);
diff --git a/src/dawn/wire/server/ServerBuffer.cpp b/src/dawn/wire/server/ServerBuffer.cpp
index f07bdfa..e5208fd 100644
--- a/src/dawn/wire/server/ServerBuffer.cpp
+++ b/src/dawn/wire/server/ServerBuffer.cpp
@@ -237,12 +237,14 @@
     cmd.readDataUpdateInfo = nullptr;
 
     const void* readData = nullptr;
+    size_t readDataUpdateInfoLength = 0;
     if (isSuccess) {
         if (isRead) {
             // Get the serialization size of the message to initialize ReadHandle data.
             readData = mProcs.bufferGetConstMappedRange(data->bufferObj, data->offset, data->size);
-            cmd.readDataUpdateInfoLength =
+            readDataUpdateInfoLength =
                 bufferData->readHandle->SizeOfSerializeDataUpdate(data->offset, data->size);
+            cmd.readDataUpdateInfoLength = readDataUpdateInfoLength;
         } else {
             ASSERT(data->mode & WGPUMapMode_Write);
             // The in-flight map request returned successfully.
@@ -259,16 +261,15 @@
         }
     }
 
-    SerializeCommand(cmd, cmd.readDataUpdateInfoLength, [&](SerializeBuffer* serializeBuffer) {
-        if (isSuccess && isRead) {
-            char* readHandleBuffer;
-            WIRE_TRY(serializeBuffer->NextN(cmd.readDataUpdateInfoLength, &readHandleBuffer));
-            // The in-flight map request returned successfully.
-            bufferData->readHandle->SerializeDataUpdate(readData, data->offset, data->size,
-                                                        readHandleBuffer);
-        }
-        return WireResult::Success;
-    });
+    SerializeCommand(cmd, CommandExtension{readDataUpdateInfoLength, [&](char* readHandleBuffer) {
+                                               if (isSuccess && isRead) {
+                                                   // The in-flight map request returned
+                                                   // successfully.
+                                                   bufferData->readHandle->SerializeDataUpdate(
+                                                       readData, data->offset, data->size,
+                                                       readHandleBuffer);
+                                               }
+                                           }});
 }
 
 }  // namespace dawn::wire::server