MemoryTransferService: Separate functions to serialize and get serialization size

Bug: dawn:156
Change-Id: I19317954c64700bdd67aa414d8eb2422d2c3544d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/9860
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn_wire/client/ApiProcs.cpp b/src/dawn_wire/client/ApiProcs.cpp
index 42a33fc..64c2fe9 100644
--- a/src/dawn_wire/client/ApiProcs.cpp
+++ b/src/dawn_wire/client/ApiProcs.cpp
@@ -27,7 +27,7 @@
                 std::is_same<Handle, MemoryTransferService::WriteHandle>::value;
 
             // Get the serialization size of the handle.
-            size_t handleCreateInfoLength = handle->SerializeCreate();
+            size_t handleCreateInfoLength = handle->SerializeCreateSize();
 
             BufferMapAsyncCmd cmd;
             cmd.bufferId = buffer->id;
@@ -171,7 +171,7 @@
         buffer->writeHandle = std::move(writeHandle);
 
         // Get the serialization size of the WriteHandle.
-        size_t handleCreateInfoLength = buffer->writeHandle->SerializeCreate();
+        size_t handleCreateInfoLength = buffer->writeHandle->SerializeCreateSize();
 
         DeviceCreateBufferMappedCmd cmd;
         cmd.device = cDevice;
@@ -246,7 +246,7 @@
         buffer->requests[serial] = std::move(request);
 
         // Get the serialization size of the WriteHandle.
-        size_t handleCreateInfoLength = writeHandle->SerializeCreate();
+        size_t handleCreateInfoLength = writeHandle->SerializeCreateSize();
 
         DeviceCreateBufferMappedAsyncCmd cmd;
         cmd.device = cDevice;
@@ -326,7 +326,7 @@
             ASSERT(buffer->readHandle == nullptr);
 
             // Get the serialization size of metadata to flush writes.
-            size_t writeFlushInfoLength = buffer->writeHandle->SerializeFlush();
+            size_t writeFlushInfoLength = buffer->writeHandle->SerializeFlushSize();
 
             BufferUpdateMappedDataCmd cmd;
             cmd.bufferId = buffer->id;
diff --git a/src/dawn_wire/client/ClientInlineMemoryTransferService.cpp b/src/dawn_wire/client/ClientInlineMemoryTransferService.cpp
index a8250ab..92542d2 100644
--- a/src/dawn_wire/client/ClientInlineMemoryTransferService.cpp
+++ b/src/dawn_wire/client/ClientInlineMemoryTransferService.cpp
@@ -26,10 +26,13 @@
 
             ~ReadHandleImpl() override = default;
 
-            size_t SerializeCreate(void*) override {
+            size_t SerializeCreateSize() override {
                 return 0;
             }
 
+            void SerializeCreate(void*) override {
+            }
+
             bool DeserializeInitialData(const void* deserializePointer,
                                         size_t deserializeSize,
                                         const void** data,
@@ -61,24 +64,29 @@
 
             ~WriteHandleImpl() override = default;
 
-            size_t SerializeCreate(void*) override {
+            size_t SerializeCreateSize() override {
                 return 0;
             }
 
+            void SerializeCreate(void*) override {
+            }
+
             std::pair<void*, size_t> Open() override {
                 mStagingData = std::unique_ptr<uint8_t[]>(new uint8_t[mSize]);
                 memset(mStagingData.get(), 0, mSize);
                 return std::make_pair(mStagingData.get(), mSize);
             }
 
-            size_t SerializeFlush(void* serializePointer) override {
-                if (serializePointer != nullptr) {
-                    ASSERT(mStagingData != nullptr);
-                    memcpy(serializePointer, mStagingData.get(), mSize);
-                }
+            size_t SerializeFlushSize() override {
                 return mSize;
             }
 
+            void SerializeFlush(void* serializePointer) override {
+                ASSERT(mStagingData != nullptr);
+                ASSERT(serializePointer != nullptr);
+                memcpy(serializePointer, mStagingData.get(), mSize);
+            }
+
           private:
             size_t mSize;
             std::unique_ptr<uint8_t[]> mStagingData;
diff --git a/src/dawn_wire/client/ClientMemoryTransferService_mock.cpp b/src/dawn_wire/client/ClientMemoryTransferService_mock.cpp
index 6762265..dd8d62f 100644
--- a/src/dawn_wire/client/ClientMemoryTransferService_mock.cpp
+++ b/src/dawn_wire/client/ClientMemoryTransferService_mock.cpp
@@ -27,8 +27,12 @@
         mService->OnReadHandleDestroy(this);
     }
 
-    size_t MockMemoryTransferService::MockReadHandle::SerializeCreate(void* serializePointer) {
-        return mService->OnReadHandleSerializeCreate(this, serializePointer);
+    size_t MockMemoryTransferService::MockReadHandle::SerializeCreateSize() {
+        return mService->OnReadHandleSerializeCreateSize(this);
+    }
+
+    void MockMemoryTransferService::MockReadHandle::SerializeCreate(void* serializePointer) {
+        mService->OnReadHandleSerializeCreate(this, serializePointer);
     }
 
     bool MockMemoryTransferService::MockReadHandle::DeserializeInitialData(
@@ -50,16 +54,24 @@
         mService->OnWriteHandleDestroy(this);
     }
 
-    size_t MockMemoryTransferService::MockWriteHandle::SerializeCreate(void* serializePointer) {
-        return mService->OnWriteHandleSerializeCreate(this, serializePointer);
+    size_t MockMemoryTransferService::MockWriteHandle::SerializeCreateSize() {
+        return mService->OnWriteHandleSerializeCreateSize(this);
+    }
+
+    void MockMemoryTransferService::MockWriteHandle::SerializeCreate(void* serializePointer) {
+        mService->OnWriteHandleSerializeCreate(this, serializePointer);
     }
 
     std::pair<void*, size_t> MockMemoryTransferService::MockWriteHandle::Open() {
         return mService->OnWriteHandleOpen(this);
     }
 
-    size_t MockMemoryTransferService::MockWriteHandle::SerializeFlush(void* serializePointer) {
-        return mService->OnWriteHandleSerializeFlush(this, serializePointer);
+    size_t MockMemoryTransferService::MockWriteHandle::SerializeFlushSize() {
+        return mService->OnWriteHandleSerializeFlushSize(this);
+    }
+
+    void MockMemoryTransferService::MockWriteHandle::SerializeFlush(void* serializePointer) {
+        mService->OnWriteHandleSerializeFlush(this, serializePointer);
     }
 
     MockMemoryTransferService::MockMemoryTransferService() = default;
diff --git a/src/dawn_wire/client/ClientMemoryTransferService_mock.h b/src/dawn_wire/client/ClientMemoryTransferService_mock.h
index c54ce23..54c7f7b 100644
--- a/src/dawn_wire/client/ClientMemoryTransferService_mock.h
+++ b/src/dawn_wire/client/ClientMemoryTransferService_mock.h
@@ -29,7 +29,8 @@
             explicit MockReadHandle(MockMemoryTransferService* service);
             ~MockReadHandle() override;
 
-            size_t SerializeCreate(void* serializePointer) override;
+            size_t SerializeCreateSize() override;
+            void SerializeCreate(void* serializePointer) override;
             bool DeserializeInitialData(const void* deserializePointer,
                                         size_t deserializeSize,
                                         const void** data,
@@ -44,9 +45,11 @@
             explicit MockWriteHandle(MockMemoryTransferService* service);
             ~MockWriteHandle() override;
 
-            size_t SerializeCreate(void* serializePointer) override;
+            size_t SerializeCreateSize() override;
+            void SerializeCreate(void* serializePointer) override;
             std::pair<void*, size_t> Open() override;
-            size_t SerializeFlush(void* serializePointer) override;
+            size_t SerializeFlushSize() override;
+            void SerializeFlush(void* serializePointer) override;
 
           private:
             MockMemoryTransferService* mService;
@@ -64,8 +67,8 @@
         MOCK_METHOD1(OnCreateReadHandle, ReadHandle*(size_t));
         MOCK_METHOD1(OnCreateWriteHandle, WriteHandle*(size_t));
 
-        MOCK_METHOD2(OnReadHandleSerializeCreate,
-                     size_t(const ReadHandle*, void* serializePointer));
+        MOCK_METHOD1(OnReadHandleSerializeCreateSize, size_t(const ReadHandle*));
+        MOCK_METHOD2(OnReadHandleSerializeCreate, void(const ReadHandle*, void* serializePointer));
         MOCK_METHOD5(OnReadHandleDeserializeInitialData,
                      bool(const ReadHandle*,
                           const uint32_t* deserializePointer,
@@ -74,11 +77,13 @@
                           size_t* dataLength));
         MOCK_METHOD1(OnReadHandleDestroy, void(const ReadHandle*));
 
+        MOCK_METHOD1(OnWriteHandleSerializeCreateSize, size_t(const void* WriteHandle));
         MOCK_METHOD2(OnWriteHandleSerializeCreate,
-                     size_t(const void* WriteHandle, void* serializePointer));
+                     void(const void* WriteHandle, void* serializePointer));
         MOCK_METHOD1(OnWriteHandleOpen, std::pair<void*, size_t>(const void* WriteHandle));
+        MOCK_METHOD1(OnWriteHandleSerializeFlushSize, size_t(const void* WriteHandle));
         MOCK_METHOD2(OnWriteHandleSerializeFlush,
-                     size_t(const void* WriteHandle, void* serializePointer));
+                     void(const void* WriteHandle, void* serializePointer));
         MOCK_METHOD1(OnWriteHandleDestroy, void(const void* WriteHandle));
     };
 
diff --git a/src/dawn_wire/server/ServerBuffer.cpp b/src/dawn_wire/server/ServerBuffer.cpp
index 2e1e78a..c969beb 100644
--- a/src/dawn_wire/server/ServerBuffer.cpp
+++ b/src/dawn_wire/server/ServerBuffer.cpp
@@ -254,7 +254,7 @@
         size_t initialDataInfoLength = 0;
         if (status == DAWN_BUFFER_MAP_ASYNC_STATUS_SUCCESS) {
             // Get the serialization size of the message to initialize ReadHandle data.
-            initialDataInfoLength = data->readHandle->SerializeInitialData(ptr, dataLength);
+            initialDataInfoLength = data->readHandle->SerializeInitialDataSize(ptr, dataLength);
         } else {
             dataLength = 0;
         }
diff --git a/src/dawn_wire/server/ServerInlineMemoryTransferService.cpp b/src/dawn_wire/server/ServerInlineMemoryTransferService.cpp
index cab0c4c..b512e6f 100644
--- a/src/dawn_wire/server/ServerInlineMemoryTransferService.cpp
+++ b/src/dawn_wire/server/ServerInlineMemoryTransferService.cpp
@@ -26,14 +26,18 @@
             }
             ~ReadHandleImpl() override = default;
 
-            size_t SerializeInitialData(const void* data,
-                                        size_t dataLength,
-                                        void* serializePointer) override {
-                if (serializePointer != nullptr && dataLength > 0) {
+            size_t SerializeInitialDataSize(const void* data, size_t dataLength) override {
+                return dataLength;
+            }
+
+            void SerializeInitialData(const void* data,
+                                      size_t dataLength,
+                                      void* serializePointer) override {
+                if (dataLength > 0) {
                     ASSERT(data != nullptr);
+                    ASSERT(serializePointer != nullptr);
                     memcpy(serializePointer, data, dataLength);
                 }
-                return dataLength;
             }
         };
 
diff --git a/src/dawn_wire/server/ServerMemoryTransferService_mock.cpp b/src/dawn_wire/server/ServerMemoryTransferService_mock.cpp
index 316a991..b8b1696 100644
--- a/src/dawn_wire/server/ServerMemoryTransferService_mock.cpp
+++ b/src/dawn_wire/server/ServerMemoryTransferService_mock.cpp
@@ -26,10 +26,15 @@
         mService->OnReadHandleDestroy(this);
     }
 
-    size_t MockMemoryTransferService::MockReadHandle::SerializeInitialData(const void* data,
-                                                                           size_t dataLength,
-                                                                           void* serializePointer) {
-        return mService->OnReadHandleSerializeInitialData(this, data, dataLength, serializePointer);
+    size_t MockMemoryTransferService::MockReadHandle::SerializeInitialDataSize(const void* data,
+                                                                               size_t dataLength) {
+        return mService->OnReadHandleSerializeInitialDataSize(this, data, dataLength);
+    }
+
+    void MockMemoryTransferService::MockReadHandle::SerializeInitialData(const void* data,
+                                                                         size_t dataLength,
+                                                                         void* serializePointer) {
+        mService->OnReadHandleSerializeInitialData(this, data, dataLength, serializePointer);
     }
 
     MockMemoryTransferService::MockWriteHandle::MockWriteHandle(MockMemoryTransferService* service)
diff --git a/src/dawn_wire/server/ServerMemoryTransferService_mock.h b/src/dawn_wire/server/ServerMemoryTransferService_mock.h
index 5065264..4e1e600 100644
--- a/src/dawn_wire/server/ServerMemoryTransferService_mock.h
+++ b/src/dawn_wire/server/ServerMemoryTransferService_mock.h
@@ -29,9 +29,10 @@
             MockReadHandle(MockMemoryTransferService* service);
             ~MockReadHandle() override;
 
-            size_t SerializeInitialData(const void* data,
-                                        size_t dataLength,
-                                        void* serializePointer) override;
+            size_t SerializeInitialDataSize(const void* data, size_t dataLength) override;
+            void SerializeInitialData(const void* data,
+                                      size_t dataLength,
+                                      void* serializePointer) override;
 
           private:
             MockMemoryTransferService* mService;
@@ -74,11 +75,13 @@
                           size_t deserializeSize,
                           WriteHandle** writeHandle));
 
+        MOCK_METHOD3(OnReadHandleSerializeInitialDataSize,
+                     size_t(const ReadHandle* readHandle, const void* data, size_t dataLength));
         MOCK_METHOD4(OnReadHandleSerializeInitialData,
-                     size_t(const ReadHandle* readHandle,
-                            const void* data,
-                            size_t dataLength,
-                            void* serializePointer));
+                     void(const ReadHandle* readHandle,
+                          const void* data,
+                          size_t dataLength,
+                          void* serializePointer));
         MOCK_METHOD1(OnReadHandleDestroy, void(const ReadHandle* readHandle));
 
         MOCK_METHOD3(OnWriteHandleDeserializeFlush,
diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h
index f3580359..458a593 100644
--- a/src/include/dawn_wire/WireClient.h
+++ b/src/include/dawn_wire/WireClient.h
@@ -76,9 +76,11 @@
 
             class DAWN_WIRE_EXPORT ReadHandle {
               public:
+                // Get the required serialization size for SerializeCreate
+                virtual size_t SerializeCreateSize() = 0;
+
                 // Serialize the handle into |serializePointer| so it can be received by the server.
-                // If |serializePointer| is nullptr, this returns the required serialization space.
-                virtual size_t SerializeCreate(void* serializePointer = nullptr) = 0;
+                virtual void SerializeCreate(void* serializePointer) = 0;
 
                 // Load initial data and open the handle for reading.
                 // This function takes in the serialized result of
@@ -95,19 +97,24 @@
 
             class DAWN_WIRE_EXPORT WriteHandle {
               public:
+                // Get the required serialization size for SerializeCreate
+                virtual size_t SerializeCreateSize() = 0;
+
                 // Serialize the handle into |serializePointer| so it can be received by the server.
-                // If |serializePointer| is nullptr, this returns the required serialization space.
-                virtual size_t SerializeCreate(void* serializePointer = nullptr) = 0;
+                virtual void SerializeCreate(void* serializePointer) = 0;
 
                 // Open the handle for reading. The data returned should be zero-initialized.
                 // The data returned must live at least until the WriteHandle is destructed.
                 // On failure, the pointer returned should be null.
                 virtual std::pair<void*, size_t> Open() = 0;
 
+                // Get the required serialization size for SerializeFlush
+                virtual size_t SerializeFlushSize() = 0;
+
                 // Flush writes to the handle. This should serialize info to send updates to the
                 // server.
-                // If |serializePointer| is nullptr, this returns the required serialization space.
-                virtual size_t SerializeFlush(void* serializePointer = nullptr) = 0;
+                virtual void SerializeFlush(void* serializePointer) = 0;
+
                 virtual ~WriteHandle();
             };
         };
diff --git a/src/include/dawn_wire/WireServer.h b/src/include/dawn_wire/WireServer.h
index 9c3da94..f5ae1dc 100644
--- a/src/include/dawn_wire/WireServer.h
+++ b/src/include/dawn_wire/WireServer.h
@@ -65,12 +65,14 @@
 
             class DAWN_WIRE_EXPORT ReadHandle {
               public:
+                // Get the required serialization size for SerializeInitialData
+                virtual size_t SerializeInitialDataSize(const void* data, size_t dataLength) = 0;
+
                 // Initialize the handle data.
                 // Serialize into |serializePointer| so the client can update handle data.
-                // If |serializePointer| is nullptr, this returns the required serialization space.
-                virtual size_t SerializeInitialData(const void* data,
-                                                    size_t dataLength,
-                                                    void* serializePointer = nullptr) = 0;
+                virtual void SerializeInitialData(const void* data,
+                                                  size_t dataLength,
+                                                  void* serializePointer) = 0;
                 virtual ~ReadHandle();
             };
 
diff --git a/src/tests/unittests/wire/WireMemoryTransferServiceTests.cpp b/src/tests/unittests/wire/WireMemoryTransferServiceTests.cpp
index cdd8419..3feb809 100644
--- a/src/tests/unittests/wire/WireMemoryTransferServiceTests.cpp
+++ b/src/tests/unittests/wire/WireMemoryTransferServiceTests.cpp
@@ -230,8 +230,9 @@
     }
 
     void ExpectReadHandleSerialization(ClientReadHandle* handle) {
+        EXPECT_CALL(clientMemoryTransferService, OnReadHandleSerializeCreateSize(handle))
+            .WillOnce(InvokeWithoutArgs([&]() { return sizeof(mSerializeCreateInfo); }));
         EXPECT_CALL(clientMemoryTransferService, OnReadHandleSerializeCreate(handle, _))
-            .WillOnce(InvokeWithoutArgs([&]() { return sizeof(mSerializeCreateInfo); }))
             .WillOnce(WithArg<1>([&](void* serializePointer) {
                 memcpy(serializePointer, &mSerializeCreateInfo, sizeof(mSerializeCreateInfo));
                 return sizeof(mSerializeCreateInfo);
@@ -261,8 +262,9 @@
     }
 
     void ExpectServerReadHandleInitialize(ServerReadHandle* handle) {
+        EXPECT_CALL(serverMemoryTransferService, OnReadHandleSerializeInitialDataSize(handle, _, _))
+            .WillOnce(InvokeWithoutArgs([&]() { return sizeof(mSerializeInitialDataInfo); }));
         EXPECT_CALL(serverMemoryTransferService, OnReadHandleSerializeInitialData(handle, _, _, _))
-            .WillOnce(InvokeWithoutArgs([&]() { return sizeof(mSerializeInitialDataInfo); }))
             .WillOnce(WithArg<3>([&](void* serializePointer) {
                 memcpy(serializePointer, &mSerializeInitialDataInfo,
                        sizeof(mSerializeInitialDataInfo));
@@ -307,8 +309,9 @@
     }
 
     void ExpectWriteHandleSerialization(ClientWriteHandle* handle) {
+        EXPECT_CALL(clientMemoryTransferService, OnWriteHandleSerializeCreateSize(handle))
+            .WillOnce(InvokeWithoutArgs([&]() { return sizeof(mSerializeCreateInfo); }));
         EXPECT_CALL(clientMemoryTransferService, OnWriteHandleSerializeCreate(handle, _))
-            .WillOnce(InvokeWithoutArgs([&]() { return sizeof(mSerializeCreateInfo); }))
             .WillOnce(WithArg<1>([&](void* serializePointer) {
                 memcpy(serializePointer, &mSerializeCreateInfo, sizeof(mSerializeCreateInfo));
                 return sizeof(mSerializeCreateInfo);
@@ -353,8 +356,9 @@
     }
 
     void ExpectClientWriteHandleSerializeFlush(ClientWriteHandle* handle) {
+        EXPECT_CALL(clientMemoryTransferService, OnWriteHandleSerializeFlushSize(handle))
+            .WillOnce(InvokeWithoutArgs([&]() { return sizeof(mSerializeFlushInfo); }));
         EXPECT_CALL(clientMemoryTransferService, OnWriteHandleSerializeFlush(handle, _))
-            .WillOnce(InvokeWithoutArgs([&]() { return sizeof(mSerializeFlushInfo); }))
             .WillOnce(WithArg<1>([&](void* serializePointer) {
                 memcpy(serializePointer, &mSerializeFlushInfo, sizeof(mSerializeFlushInfo));
                 return sizeof(mSerializeFlushInfo);