Add wire serialized command buffer padding.
Pads serialized wire command buffers to 8 bytes so that we don't have
misaligned write/reads which can cause SIGILL depending on platform and
compilation mode, i.e. -c dbg in google3 builds.
- Adds helpers for aligning sizeof calls.
- Adds constant for wire padding (8u).
- Modifies BufferConsumer to allocate according to padding. This
guarantees that when we [de]serialize stuff, the padding should be
equal on both sides.
- Modifies extra byte serialization code (adding CommandExtension
struct). This makes it clearer that each extension needs to be
padded independently. Otherwise, before in wire/client/Buffer.cpp,
since the read/write handle sizes were being passed as a sum, but
read out separately from the BufferConsumer, we corrupt our pointers.
- Adds some simple unit tests.
Bug: dawn:1334
Change-Id: Id80e7c01a34b9f01c3f02b3e6c04c3bb3ad0eff9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110501
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/generator/templates/dawn/wire/WireCmd.cpp b/generator/templates/dawn/wire/WireCmd.cpp
index 26c0291..10a7d36 100644
--- a/generator/templates/dawn/wire/WireCmd.cpp
+++ b/generator/templates/dawn/wire/WireCmd.cpp
@@ -165,7 +165,7 @@
if (has_{{memberName}})
{% endif %}
{
- result += std::strlen(record.{{memberName}});
+ result += Align(std::strlen(record.{{memberName}}), kWireBufferAlignment);
}
{% endfor %}
@@ -178,7 +178,9 @@
{% if member.annotation != "value" %}
{{ assert(member.annotation != "const*const*") }}
auto memberLength = {{member_length(member, "record.")}};
- result += memberLength * {{member_transfer_sizeof(member)}};
+ auto size = WireAlignSizeofN<{{member_transfer_type(member)}}>(memberLength);
+ ASSERT(size);
+ result += *size;
//* Structures might contain more pointers so we need to add their extra size as well.
{% if member.type.category == "structure" %}
for (decltype(memberLength) i = 0; i < memberLength; ++i) {
@@ -431,7 +433,7 @@
{% set Cmd = Name + "Cmd" %}
size_t {{Cmd}}::GetRequiredSize() const {
- size_t size = sizeof({{Name}}Transfer) + {{Name}}GetExtraRequiredSize(*this);
+ size_t size = WireAlignSizeof<{{Name}}Transfer>() + {{Name}}GetExtraRequiredSize(*this);
return size;
}
@@ -509,7 +511,7 @@
) %}
case {{as_cEnum(types["s type"].name, sType.name)}}: {
const auto& typedStruct = *reinterpret_cast<{{as_cType(sType.name)}} const *>(chainedStruct);
- result += sizeof({{as_cType(sType.name)}}Transfer);
+ result += WireAlignSizeof<{{as_cType(sType.name)}}Transfer>();
result += {{as_cType(sType.name)}}GetExtraRequiredSize(typedStruct);
chainedStruct = typedStruct.chain.next;
break;
@@ -519,7 +521,7 @@
case WGPUSType_Invalid:
default:
// Invalid enum. Reserve space just for the transfer header (sType and hasNext).
- result += sizeof(WGPUChainedStructTransfer);
+ result += WireAlignSizeof<WGPUChainedStructTransfer>();
chainedStruct = chainedStruct->next;
break;
}
@@ -600,7 +602,7 @@
WIRE_TRY(deserializeBuffer->Read(&transfer));
{{CType}}* outStruct;
- WIRE_TRY(GetSpace(allocator, sizeof({{CType}}), &outStruct));
+ WIRE_TRY(GetSpace(allocator, 1u, &outStruct));
outStruct->chain.sType = sType;
outStruct->chain.next = nullptr;
@@ -629,7 +631,7 @@
WIRE_TRY(deserializeBuffer->Read(&transfer));
{{ChainedStruct}}* outStruct;
- WIRE_TRY(GetSpace(allocator, sizeof({{ChainedStruct}}), &outStruct));
+ WIRE_TRY(GetSpace(allocator, 1u, &outStruct));
outStruct->sType = WGPUSType_Invalid;
outStruct->next = nullptr;
@@ -654,13 +656,23 @@
// Always writes to |out| on success.
template <typename T, typename N>
WireResult GetSpace(DeserializeAllocator* allocator, N count, T** out) {
- constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T);
- if (count > kMaxCountWithoutOverflows) {
+ // Because we use this function extensively when `count` == 1, we can optimize the
+ // size computations a bit more for those cases via constexpr version of the
+ // alignment computation.
+ constexpr size_t kSizeofT = WireAlignSizeof<T>();
+ size_t size = 0;
+ if (count == 1) {
+ size = kSizeofT;
+ } else {
+ auto sizeN = WireAlignSizeofN<T>(count);
+ // A size of 0 indicates an overflow, so return an error.
+ if (!sizeN) {
return WireResult::FatalError;
+ }
+ size = *sizeN;
}
- size_t totalSize = sizeof(T) * count;
- *out = static_cast<T*>(allocator->GetSpace(totalSize));
+ *out = static_cast<T*>(allocator->GetSpace(size));
if (*out == nullptr) {
return WireResult::FatalError;
}
diff --git a/src/dawn/common/Constants.h b/src/dawn/common/Constants.h
index bf1f132..027b6c0 100644
--- a/src/dawn/common/Constants.h
+++ b/src/dawn/common/Constants.h
@@ -15,6 +15,7 @@
#ifndef SRC_DAWN_COMMON_CONSTANTS_H_
#define SRC_DAWN_COMMON_CONSTANTS_H_
+#include <cstddef>
#include <cstdint>
static constexpr uint32_t kMaxBindGroups = 4u;
@@ -65,4 +66,7 @@
static constexpr uint8_t kSamplersPerExternalTexture = 1u;
static constexpr uint8_t kUniformsPerExternalTexture = 1u;
+// Wire buffer alignments.
+static constexpr size_t kWireBufferAlignment = 8u;
+
#endif // SRC_DAWN_COMMON_CONSTANTS_H_
diff --git a/src/dawn/common/Math.h b/src/dawn/common/Math.h
index 9984c4b..c8b518f 100644
--- a/src/dawn/common/Math.h
+++ b/src/dawn/common/Math.h
@@ -20,6 +20,7 @@
#include <cstring>
#include <limits>
+#include <optional>
#include <type_traits>
#include "dawn/common/Assert.h"
@@ -61,6 +62,26 @@
return (value + (alignmentT - 1)) & ~(alignmentT - 1);
}
+template <typename T, size_t Alignment>
+constexpr size_t AlignSizeof() {
+ static_assert(Alignment != 0 && (Alignment & (Alignment - 1)) == 0,
+ "Alignment must be a valid power of 2.");
+ static_assert(sizeof(T) <= std::numeric_limits<size_t>::max() - (Alignment - 1));
+ return (sizeof(T) + (Alignment - 1)) & ~(Alignment - 1);
+}
+
+// Returns an aligned size for an n-sized array of T elements. If the size would overflow, returns
+// nullopt instead.
+template <typename T, size_t Alignment>
+std::optional<size_t> AlignSizeofN(uint64_t n) {
+ constexpr uint64_t kMaxCountWithoutOverflows =
+ (std::numeric_limits<size_t>::max() - Alignment + 1) / sizeof(T);
+ if (n > kMaxCountWithoutOverflows) {
+ return std::nullopt;
+ }
+ return Align(sizeof(T) * n, Alignment);
+}
+
template <typename T>
DAWN_FORCE_INLINE T* AlignPtr(T* ptr, size_t alignment) {
ASSERT(IsPowerOfTwo(alignment));
diff --git a/src/dawn/tests/unittests/MathTests.cpp b/src/dawn/tests/unittests/MathTests.cpp
index d88e858..a38fb67 100644
--- a/src/dawn/tests/unittests/MathTests.cpp
+++ b/src/dawn/tests/unittests/MathTests.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include <cmath>
+#include <limits>
#include <vector>
#include "dawn/EnumClassBitmasks.h"
@@ -180,6 +181,41 @@
ASSERT_EQ(Align(static_cast<uint64_t>(0xFFFFFFFFFFFFFFFF), 1), 0xFFFFFFFFFFFFFFFFull);
}
+TEST(Math, AlignSizeof) {
+ // Basic types should align to self if alignment is a divisor.
+ ASSERT_EQ((AlignSizeof<uint8_t, 1>()), 1u);
+
+ ASSERT_EQ((AlignSizeof<uint16_t, 1>()), 2u);
+ ASSERT_EQ((AlignSizeof<uint16_t, 2>()), 2u);
+
+ ASSERT_EQ((AlignSizeof<uint32_t, 1>()), 4u);
+ ASSERT_EQ((AlignSizeof<uint32_t, 2>()), 4u);
+ ASSERT_EQ((AlignSizeof<uint32_t, 4>()), 4u);
+
+ ASSERT_EQ((AlignSizeof<uint64_t, 1>()), 8u);
+ ASSERT_EQ((AlignSizeof<uint64_t, 2>()), 8u);
+ ASSERT_EQ((AlignSizeof<uint64_t, 4>()), 8u);
+ ASSERT_EQ((AlignSizeof<uint64_t, 8>()), 8u);
+
+ // Everything in range (align, 2*align] aligns to 2*align.
+ ASSERT_EQ((AlignSizeof<char[5], 4>()), 8u);
+ ASSERT_EQ((AlignSizeof<char[6], 4>()), 8u);
+ ASSERT_EQ((AlignSizeof<char[7], 4>()), 8u);
+ ASSERT_EQ((AlignSizeof<char[8], 4>()), 8u);
+}
+
+TEST(Math, AlignSizeofN) {
+ // Everything in range (align, 2*align] aligns to 2*align.
+ ASSERT_EQ(*(AlignSizeofN<char, 4>(5)), 8u);
+ ASSERT_EQ(*(AlignSizeofN<char, 4>(6)), 8u);
+ ASSERT_EQ(*(AlignSizeofN<char, 4>(7)), 8u);
+ ASSERT_EQ(*(AlignSizeofN<char, 4>(8)), 8u);
+
+ // Extremes should return nullopt.
+ ASSERT_EQ((AlignSizeofN<char, 4>(std::numeric_limits<size_t>::max())), std::nullopt);
+ ASSERT_EQ((AlignSizeofN<char, 4>(std::numeric_limits<uint64_t>::max())), std::nullopt);
+}
+
// Tests for IsPtrAligned
TEST(Math, IsPtrAligned) {
constexpr size_t kTestAlignment = 8;
diff --git a/src/dawn/wire/BufferConsumer.h b/src/dawn/wire/BufferConsumer.h
index 1ae8451..12c4036 100644
--- a/src/dawn/wire/BufferConsumer.h
+++ b/src/dawn/wire/BufferConsumer.h
@@ -17,10 +17,22 @@
#include <cstddef>
+#include "dawn/common/Constants.h"
+#include "dawn/common/Math.h"
#include "dawn/wire/WireResult.h"
namespace dawn::wire {
+// Wire specific alignment helpers.
+template <typename T>
+constexpr size_t WireAlignSizeof() {
+ return AlignSizeof<T, kWireBufferAlignment>();
+}
+template <typename T>
+std::optional<size_t> WireAlignSizeofN(size_t n) {
+ return AlignSizeofN<T, kWireBufferAlignment>(n);
+}
+
// BufferConsumer is a utility class that allows reading bytes from a buffer
// while simultaneously decrementing the amount of remaining space by exactly
// the amount read. It helps prevent bugs where incrementing a pointer and
diff --git a/src/dawn/wire/BufferConsumer_impl.h b/src/dawn/wire/BufferConsumer_impl.h
index 6b5d0a1..52b6720 100644
--- a/src/dawn/wire/BufferConsumer_impl.h
+++ b/src/dawn/wire/BufferConsumer_impl.h
@@ -15,11 +15,11 @@
#ifndef SRC_DAWN_WIRE_BUFFERCONSUMER_IMPL_H_
#define SRC_DAWN_WIRE_BUFFERCONSUMER_IMPL_H_
-#include "dawn/wire/BufferConsumer.h"
-
#include <limits>
#include <type_traits>
+#include "dawn/wire/BufferConsumer.h"
+
namespace dawn::wire {
template <typename BufferT>
@@ -36,13 +36,14 @@
template <typename BufferT>
template <typename T>
WireResult BufferConsumer<BufferT>::Next(T** data) {
- if (sizeof(T) > mSize) {
+ constexpr size_t kSize = WireAlignSizeof<T>();
+ if (kSize > mSize) {
return WireResult::FatalError;
}
*data = reinterpret_cast<T*>(mBuffer);
- mBuffer += sizeof(T);
- mSize -= sizeof(T);
+ mBuffer += kSize;
+ mSize -= kSize;
return WireResult::Success;
}
@@ -51,20 +52,15 @@
WireResult BufferConsumer<BufferT>::NextN(N count, T** data) {
static_assert(std::is_unsigned<N>::value, "|count| argument of NextN must be unsigned.");
- constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T);
- if (count > kMaxCountWithoutOverflows) {
- return WireResult::FatalError;
- }
-
- // Cannot overflow because |count| is not greater than |kMaxCountWithoutOverflows|.
- size_t totalSize = sizeof(T) * count;
- if (totalSize > mSize) {
+ // If size is zero then it indicates an overflow.
+ auto size = WireAlignSizeofN<T>(count);
+ if (size && *size > mSize) {
return WireResult::FatalError;
}
*data = reinterpret_cast<T*>(mBuffer);
- mBuffer += totalSize;
- mSize -= totalSize;
+ mBuffer += *size;
+ mSize -= *size;
return WireResult::Success;
}
diff --git a/src/dawn/wire/ChunkedCommandSerializer.h b/src/dawn/wire/ChunkedCommandSerializer.h
index 7ac72e5..8ad64ff 100644
--- a/src/dawn/wire/ChunkedCommandSerializer.h
+++ b/src/dawn/wire/ChunkedCommandSerializer.h
@@ -17,73 +17,97 @@
#include <algorithm>
#include <cstring>
+#include <functional>
#include <memory>
#include <utility>
#include "dawn/common/Alloc.h"
#include "dawn/common/Compiler.h"
+#include "dawn/common/Constants.h"
+#include "dawn/common/Math.h"
#include "dawn/wire/Wire.h"
#include "dawn/wire/WireCmd_autogen.h"
namespace dawn::wire {
+// Simple command extension struct used when a command needs to serialize additional information
+// that is not baked directly into the command already.
+struct CommandExtension {
+ size_t size;
+ std::function<void(char*)> serialize;
+};
+
+namespace detail {
+
+inline WireResult SerializeCommandExtension(SerializeBuffer* serializeBuffer) {
+ return WireResult::Success;
+}
+
+template <typename Extension, typename... Extensions>
+WireResult SerializeCommandExtension(SerializeBuffer* serializeBuffer,
+ Extension&& e,
+ Extensions&&... es) {
+ char* buffer;
+ WIRE_TRY(serializeBuffer->NextN(e.size, &buffer));
+ e.serialize(buffer);
+
+ WIRE_TRY(SerializeCommandExtension(serializeBuffer, std::forward<Extensions>(es)...));
+ return WireResult::Success;
+}
+
+} // namespace detail
+
class ChunkedCommandSerializer {
public:
explicit ChunkedCommandSerializer(CommandSerializer* serializer);
template <typename Cmd>
void SerializeCommand(const Cmd& cmd) {
- SerializeCommand(cmd, 0, [](SerializeBuffer*) { return WireResult::Success; });
+ SerializeCommandImpl(
+ cmd, [](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) {
+ return cmd.Serialize(requiredSize, serializeBuffer);
+ });
}
- template <typename Cmd, typename ExtraSizeSerializeFn>
- void SerializeCommand(const Cmd& cmd,
- size_t extraSize,
- ExtraSizeSerializeFn&& SerializeExtraSize) {
+ template <typename Cmd, typename... Extensions>
+ void SerializeCommand(const Cmd& cmd, CommandExtension&& e, Extensions&&... es) {
SerializeCommandImpl(
cmd,
[](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) {
return cmd.Serialize(requiredSize, serializeBuffer);
},
- extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize));
+ std::forward<CommandExtension>(e), std::forward<Extensions>(es)...);
}
- template <typename Cmd>
- void SerializeCommand(const Cmd& cmd, const ObjectIdProvider& objectIdProvider) {
- SerializeCommand(cmd, objectIdProvider, 0,
- [](SerializeBuffer*) { return WireResult::Success; });
- }
-
- template <typename Cmd, typename ExtraSizeSerializeFn>
+ template <typename Cmd, typename... Extensions>
void SerializeCommand(const Cmd& cmd,
const ObjectIdProvider& objectIdProvider,
- size_t extraSize,
- ExtraSizeSerializeFn&& SerializeExtraSize) {
+ Extensions&&... extensions) {
SerializeCommandImpl(
cmd,
[&objectIdProvider](const Cmd& cmd, size_t requiredSize,
SerializeBuffer* serializeBuffer) {
return cmd.Serialize(requiredSize, serializeBuffer, objectIdProvider);
},
- extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize));
+ std::forward<Extensions>(extensions)...);
}
private:
- template <typename Cmd, typename SerializeCmdFn, typename ExtraSizeSerializeFn>
+ template <typename Cmd, typename SerializeCmdFn, typename... Extensions>
void SerializeCommandImpl(const Cmd& cmd,
SerializeCmdFn&& SerializeCmd,
- size_t extraSize,
- ExtraSizeSerializeFn&& SerializeExtraSize) {
+ Extensions&&... extensions) {
size_t commandSize = cmd.GetRequiredSize();
- size_t requiredSize = commandSize + extraSize;
+ size_t requiredSize = (Align(extensions.size, kWireBufferAlignment) + ... + commandSize);
if (requiredSize <= mMaxAllocationSize) {
char* allocatedBuffer = static_cast<char*>(mSerializer->GetCmdSpace(requiredSize));
if (allocatedBuffer != nullptr) {
SerializeBuffer serializeBuffer(allocatedBuffer, requiredSize);
- WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer);
- WireResult r2 = SerializeExtraSize(&serializeBuffer);
- if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) {
+ WireResult rCmd = SerializeCmd(cmd, requiredSize, &serializeBuffer);
+ WireResult rExts =
+ detail::SerializeCommandExtension(&serializeBuffer, extensions...);
+ if (DAWN_UNLIKELY(rCmd != WireResult::Success || rExts != WireResult::Success)) {
mSerializer->OnSerializeError();
}
}
@@ -95,9 +119,9 @@
return;
}
SerializeBuffer serializeBuffer(cmdSpace.get(), requiredSize);
- WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer);
- WireResult r2 = SerializeExtraSize(&serializeBuffer);
- if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) {
+ WireResult rCmd = SerializeCmd(cmd, requiredSize, &serializeBuffer);
+ WireResult rExts = detail::SerializeCommandExtension(&serializeBuffer, extensions...);
+ if (DAWN_UNLIKELY(rCmd != WireResult::Success || rExts != WireResult::Success)) {
mSerializer->OnSerializeError();
return;
}
diff --git a/src/dawn/wire/client/Buffer.cpp b/src/dawn/wire/client/Buffer.cpp
index 32da663..4315452 100644
--- a/src/dawn/wire/client/Buffer.cpp
+++ b/src/dawn/wire/client/Buffer.cpp
@@ -47,6 +47,8 @@
cmd.writeHandleCreateInfoLength = 0;
cmd.writeHandleCreateInfo = nullptr;
+ size_t readHandleCreateInfoLength = 0;
+ size_t writeHandleCreateInfoLength = 0;
if (mappable) {
if ((descriptor->usage & WGPUBufferUsage_MapRead) != 0) {
// Create the read handle on buffer creation.
@@ -56,7 +58,8 @@
device->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping");
return CreateError(device, descriptor);
}
- cmd.readHandleCreateInfoLength = readHandle->SerializeCreateSize();
+ readHandleCreateInfoLength = readHandle->SerializeCreateSize();
+ cmd.readHandleCreateInfoLength = readHandleCreateInfoLength;
}
if ((descriptor->usage & WGPUBufferUsage_MapWrite) != 0 || descriptor->mappedAtCreation) {
@@ -67,7 +70,8 @@
device->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping");
return CreateError(device, descriptor);
}
- cmd.writeHandleCreateInfoLength = writeHandle->SerializeCreateSize();
+ writeHandleCreateInfoLength = writeHandle->SerializeCreateSize();
+ cmd.writeHandleCreateInfoLength = writeHandleCreateInfoLength;
}
}
@@ -95,27 +99,28 @@
cmd.result = buffer->GetWireHandle();
+ // clang-format off
+ // Turning off clang format here because for some reason it does not format the
+ // CommandExtensions consistently, making it harder to read.
wireClient->SerializeCommand(
- cmd, cmd.readHandleCreateInfoLength + cmd.writeHandleCreateInfoLength,
- [&](SerializeBuffer* serializeBuffer) {
- if (readHandle != nullptr) {
- char* readHandleBuffer;
- WIRE_TRY(serializeBuffer->NextN(cmd.readHandleCreateInfoLength, &readHandleBuffer));
- // Serialize the ReadHandle into the space after the command.
- readHandle->SerializeCreate(readHandleBuffer);
- buffer->mReadHandle = std::move(readHandle);
- }
- if (writeHandle != nullptr) {
- char* writeHandleBuffer;
- WIRE_TRY(
- serializeBuffer->NextN(cmd.writeHandleCreateInfoLength, &writeHandleBuffer));
- // Serialize the WriteHandle into the space after the command.
- writeHandle->SerializeCreate(writeHandleBuffer);
- buffer->mWriteHandle = std::move(writeHandle);
- }
-
- return WireResult::Success;
- });
+ cmd,
+ CommandExtension{readHandleCreateInfoLength,
+ [&](char* readHandleBuffer) {
+ if (readHandle != nullptr) {
+ // Serialize the ReadHandle into the space after the command.
+ readHandle->SerializeCreate(readHandleBuffer);
+ buffer->mReadHandle = std::move(readHandle);
+ }
+ }},
+ CommandExtension{writeHandleCreateInfoLength,
+ [&](char* writeHandleBuffer) {
+ if (writeHandle != nullptr) {
+ // Serialize the WriteHandle into the space after the command.
+ writeHandle->SerializeCreate(writeHandleBuffer);
+ buffer->mWriteHandle = std::move(writeHandle);
+ }
+ }});
+ // clang-format on
return ToAPI(buffer);
}
@@ -310,16 +315,12 @@
cmd.size = mMapSize;
client->SerializeCommand(
- cmd, writeDataUpdateInfoLength, [&](SerializeBuffer* serializeBuffer) {
- char* writeHandleBuffer;
- WIRE_TRY(serializeBuffer->NextN(writeDataUpdateInfoLength, &writeHandleBuffer));
-
- // Serialize flush metadata into the space after the command.
- // This closes the handle for writing.
- mWriteHandle->SerializeDataUpdate(writeHandleBuffer, cmd.offset, cmd.size);
-
- return WireResult::Success;
- });
+ cmd, CommandExtension{writeDataUpdateInfoLength, [&](char* writeHandleBuffer) {
+ // Serialize flush metadata into the space after the command.
+ // This closes the handle for writing.
+ mWriteHandle->SerializeDataUpdate(writeHandleBuffer,
+ cmd.offset, cmd.size);
+ }});
// If mDestructWriteHandleOnUnmap is true, that means the write handle is merely
// for mappedAtCreation usage. It is destroyed on unmap after flush to server
diff --git a/src/dawn/wire/client/Client.h b/src/dawn/wire/client/Client.h
index 8648522..f16af64 100644
--- a/src/dawn/wire/client/Client.h
+++ b/src/dawn/wire/client/Client.h
@@ -85,11 +85,9 @@
mSerializer.SerializeCommand(cmd, *this);
}
- template <typename Cmd, typename ExtraSizeSerializeFn>
- void SerializeCommand(const Cmd& cmd,
- size_t extraSize,
- ExtraSizeSerializeFn&& SerializeExtraSize) {
- mSerializer.SerializeCommand(cmd, *this, extraSize, SerializeExtraSize);
+ template <typename Cmd, typename... Extensions>
+ void SerializeCommand(const Cmd& cmd, Extensions&&... es) {
+ mSerializer.SerializeCommand(cmd, *this, std::forward<Extensions>(es)...);
}
void Disconnect();
diff --git a/src/dawn/wire/server/Server.h b/src/dawn/wire/server/Server.h
index 2812756..2dfaba7 100644
--- a/src/dawn/wire/server/Server.h
+++ b/src/dawn/wire/server/Server.h
@@ -184,11 +184,9 @@
mSerializer.SerializeCommand(cmd);
}
- template <typename Cmd, typename ExtraSizeSerializeFn>
- void SerializeCommand(const Cmd& cmd,
- size_t extraSize,
- ExtraSizeSerializeFn&& SerializeExtraSize) {
- mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize);
+ template <typename Cmd, typename... Extensions>
+ void SerializeCommand(const Cmd& cmd, Extensions&&... es) {
+ mSerializer.SerializeCommand(cmd, std::forward<Extensions>(es)...);
}
void SetForwardingDeviceCallbacks(ObjectData<WGPUDevice>* deviceObject);
diff --git a/src/dawn/wire/server/ServerBuffer.cpp b/src/dawn/wire/server/ServerBuffer.cpp
index f07bdfa..e5208fd 100644
--- a/src/dawn/wire/server/ServerBuffer.cpp
+++ b/src/dawn/wire/server/ServerBuffer.cpp
@@ -237,12 +237,14 @@
cmd.readDataUpdateInfo = nullptr;
const void* readData = nullptr;
+ size_t readDataUpdateInfoLength = 0;
if (isSuccess) {
if (isRead) {
// Get the serialization size of the message to initialize ReadHandle data.
readData = mProcs.bufferGetConstMappedRange(data->bufferObj, data->offset, data->size);
- cmd.readDataUpdateInfoLength =
+ readDataUpdateInfoLength =
bufferData->readHandle->SizeOfSerializeDataUpdate(data->offset, data->size);
+ cmd.readDataUpdateInfoLength = readDataUpdateInfoLength;
} else {
ASSERT(data->mode & WGPUMapMode_Write);
// The in-flight map request returned successfully.
@@ -259,16 +261,15 @@
}
}
- SerializeCommand(cmd, cmd.readDataUpdateInfoLength, [&](SerializeBuffer* serializeBuffer) {
- if (isSuccess && isRead) {
- char* readHandleBuffer;
- WIRE_TRY(serializeBuffer->NextN(cmd.readDataUpdateInfoLength, &readHandleBuffer));
- // The in-flight map request returned successfully.
- bufferData->readHandle->SerializeDataUpdate(readData, data->offset, data->size,
- readHandleBuffer);
- }
- return WireResult::Success;
- });
+ SerializeCommand(cmd, CommandExtension{readDataUpdateInfoLength, [&](char* readHandleBuffer) {
+ if (isSuccess && isRead) {
+ // The in-flight map request returned
+ // successfully.
+ bufferData->readHandle->SerializeDataUpdate(
+ readData, data->offset, data->size,
+ readHandleBuffer);
+ }
+ }});
}
} // namespace dawn::wire::server