[dawn][wire] Check that the buffer is mapped in DeserializeDataUpdate. Previously the target of the WriteHandle for a buffer was set as soon as the buffer was mapped. Between the time it was first mapped and the time DeserializeDataUpdate was called (right before Unmap), the buffer could be implicitly unmapped by a call to Device::Destroy, after which the server would still try to write into the stale mapping. - Instead check for the buffer being mapped directly in DeserializeDataUpdate, which removes the need to track a map write state on the ObjectData<WGPUBuffer>. - Update the change-detecting WireTests to account for GetMappedRange now being called in a different place for writable buffers. - Add a new test fixture that allows injecting WireCmds directly, for even more precise (but even more change-detecting) tests. - Add the backdoors to WireClient and WireTest needed for the new tests. - Link dawn::wire statically in dawn_unittests as we now need to use some of its internals directly. Bug: 492139412 Change-Id: Ibe9ab95ae7456c6629434d4978f439ebfe41c4d1 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/296817 Reviewed-by: Loko Kung <lokokung@google.com> Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/include/dawn/wire/WireClient.h b/include/dawn/wire/WireClient.h index 3b935ce..6472e40 100644 --- a/include/dawn/wire/WireClient.h +++ b/include/dawn/wire/WireClient.h
@@ -96,6 +96,8 @@ // Commands allocated after this point will not be sent. void Disconnect(); + client::Client* GetImplForTesting(); + private: std::unique_ptr<client::Client> mImpl; };
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn index 04900b6..ed0fc2e 100644 --- a/src/dawn/tests/BUILD.gn +++ b/src/dawn/tests/BUILD.gn
@@ -468,6 +468,7 @@ "unittests/wire/WireOptionalTests.cpp", "unittests/wire/WireQueueTests.cpp", "unittests/wire/WireShaderModuleTests.cpp", + "unittests/wire/WireSpecificCommandTests.cpp", "unittests/wire/WireTest.cpp", "unittests/wire/WireTest.h", ]
diff --git a/src/dawn/tests/unittests/wire/WireBufferMappingTests.cpp b/src/dawn/tests/unittests/wire/WireBufferMappingTests.cpp index 0346239..5fbfa97 100644 --- a/src/dawn/tests/unittests/wire/WireBufferMappingTests.cpp +++ b/src/dawn/tests/unittests/wire/WireBufferMappingTests.cpp
@@ -123,14 +123,14 @@ } // Sets up the correct mapped range call expectations given the map mode. - void ExpectMappedRangeCall(uint64_t bufferSize, void* bufferContent) { + void ExpectMappedRangeCall() { wgpu::MapMode mapMode = GetMapMode(); if (mapMode == wgpu::MapMode::Read) { - EXPECT_CALL(api, BufferGetConstMappedRange(apiBuffer, 0, bufferSize)) - .WillOnce(Return(bufferContent)); + EXPECT_CALL(api, BufferGetConstMappedRange(apiBuffer, 0, kBufferSize)) + .WillOnce(Return(&mappedBufferContents)); } else if (mapMode == wgpu::MapMode::Write) { - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, bufferSize)) - .WillOnce(Return(bufferContent)); + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) + .WillOnce(Return(&mappedBufferContents)); } } @@ -144,14 +144,15 @@ wgpu::MapMode mapMode = GetMapMode(); MapAsync(mapMode, 0, kBufferSize); - uint32_t bufferContent = 31337; EXPECT_CALL( api, OnBufferMapAsync(apiBuffer, static_cast<WGPUMapMode>(mapMode), 0, kBufferSize, _)) .WillOnce(InvokeWithoutArgs([&] { api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - ExpectMappedRangeCall(kBufferSize, &bufferContent); + if (mapMode & wgpu::MapMode::Read) { + ExpectMappedRangeCall(); + } addExpectations(); // The callback should get called with the expected status, regardless if the server has @@ -230,14 +231,15 @@ wgpu::MapMode mapMode = GetMapMode(); MapAsync(mapMode, 0, kBufferSize); - uint32_t bufferContent = 31337; EXPECT_CALL( api, OnBufferMapAsync(apiBuffer, static_cast<WGPUMapMode>(mapMode), 0, kBufferSize, _)) .WillOnce(InvokeWithoutArgs([&] { api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - ExpectMappedRangeCall(kBufferSize, &bufferContent); + if (mapMode & wgpu::MapMode::Read) { + ExpectMappedRangeCall(); + } // Ensure that the server had a chance to respond if relevant. 
FlushClient(); @@ -260,7 +262,11 @@ FlushCallbacks(); } + // The buffer contents is in a member to ensure it outlives all test bodies (it is passed by + // pointer as the mocked result of GetMappedRange and can be derefenced anywhere in the test). + uint32_t mappedBufferContents = 31337; static constexpr uint64_t kBufferSize = sizeof(uint32_t); + // A successfully created buffer wgpu::Buffer buffer; WGPUBuffer apiBuffer; @@ -363,7 +369,12 @@ // Test that the callback isn't fired twice when Unmap() is called inside the callback. TEST_P(WireBufferMappingTests, UnmapInsideMapCallback) { TestCancelInCallback([&]() { buffer.Unmap(); }, - [&]() { EXPECT_CALL(api, BufferUnmap(apiBuffer)); }); + [&]() { + if (GetMapMode() & wgpu::MapMode::Write) { + ExpectMappedRangeCall(); + } + EXPECT_CALL(api, BufferUnmap(apiBuffer)); + }); } // Test that the callback isn't fired twice when Destroy() is called inside the callback. @@ -394,14 +405,13 @@ TEST_P(WireBufferMappingReadTests, MappingSuccess) { MapAsync(wgpu::MapMode::Read, 0, kBufferSize); - uint32_t bufferContent = 31337; EXPECT_CALL(api, OnBufferMapAsync(apiBuffer, WGPUMapMode_Read, 0, kBufferSize, _)) .WillOnce(InvokeWithoutArgs([&] { api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); EXPECT_CALL(api, BufferGetConstMappedRange(apiBuffer, 0, kBufferSize)) - .WillOnce(Return(&bufferContent)); + .WillOnce(Return(&mappedBufferContents)); FlushClient(); FlushFutures(); @@ -411,7 +421,7 @@ FlushCallbacks(); }); - EXPECT_EQ(bufferContent, + EXPECT_EQ(mappedBufferContents, *static_cast<const uint32_t*>(buffer.GetConstMappedRange(0, kBufferSize))); EXPECT_CALL(api, BufferUnmap(apiBuffer)).Times(1); buffer.Unmap(); @@ -424,14 +434,13 @@ // Successful map MapAsync(wgpu::MapMode::Read, 0, kBufferSize); - uint32_t bufferContent = 31337; EXPECT_CALL(api, OnBufferMapAsync(apiBuffer, WGPUMapMode_Read, 0, kBufferSize, _)) .WillOnce(InvokeWithoutArgs([&] { 
api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); EXPECT_CALL(api, BufferGetConstMappedRange(apiBuffer, 0, kBufferSize)) - .WillOnce(Return(&bufferContent)); + .WillOnce(Return(&mappedBufferContents)); FlushClient(); FlushFutures(); @@ -459,7 +468,7 @@ FlushCallbacks(); }); - EXPECT_EQ(bufferContent, + EXPECT_EQ(mappedBufferContents, *static_cast<const uint32_t*>(buffer.GetConstMappedRange(0, kBufferSize))); } @@ -478,7 +487,6 @@ TEST_P(WireBufferMappingWriteTests, MappingSuccess) { MapAsync(wgpu::MapMode::Write, 0, kBufferSize); - uint32_t serverBufferContent = 31337; uint32_t updatedContent = 4242; EXPECT_CALL(api, OnBufferMapAsync(apiBuffer, WGPUMapMode_Write, 0, kBufferSize, _)) @@ -486,8 +494,6 @@ api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) - .WillOnce(Return(&serverBufferContent)); // The map write callback always gets a buffer full of zeroes. FlushClient(); @@ -504,13 +510,15 @@ // Write something to the mapped pointer *lastMapWritePointer = updatedContent; + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) + .WillOnce(Return(&mappedBufferContents)); EXPECT_CALL(api, BufferUnmap(apiBuffer)).Times(1); buffer.Unmap(); FlushClient(); // After the buffer is unmapped, the content of the buffer is updated on the server - ASSERT_EQ(serverBufferContent, updatedContent); + ASSERT_EQ(mappedBufferContents, updatedContent); } // Check that an error map write while a buffer is already mapped. 
@@ -518,14 +526,11 @@ // Successful map MapAsync(wgpu::MapMode::Write, 0, kBufferSize); - uint32_t bufferContent = 31337; EXPECT_CALL(api, OnBufferMapAsync(apiBuffer, WGPUMapMode_Write, 0, kBufferSize, _)) .WillOnce(InvokeWithoutArgs([&] { api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) - .WillOnce(Return(&bufferContent)); FlushClient(); FlushFutures(); @@ -574,11 +579,10 @@ uint32_t apiBufferData = 1234; EXPECT_CALL(api, DeviceCreateBuffer(apiDevice, _)).WillOnce(Return(apiBuffer)); - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, 4)).WillOnce(Return(&apiBufferData)); - buffer = device.CreateBuffer(&descriptor); FlushClient(); + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, 4)).WillOnce(Return(&apiBufferData)); EXPECT_CALL(api, BufferUnmap(apiBuffer)).Times(1); buffer.Unmap(); FlushClient(); @@ -590,10 +594,7 @@ descriptor.size = kBufferSize; descriptor.mappedAtCreation = true; - uint32_t apiBufferData = 1234; EXPECT_CALL(api, DeviceCreateBuffer(apiDevice, _)).WillOnce(Return(apiBuffer)); - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, 4)).WillOnce(Return(&apiBufferData)); - buffer = device.CreateBuffer(&descriptor); FlushClient(); @@ -611,11 +612,10 @@ uint32_t apiBufferData = 1234; EXPECT_CALL(api, DeviceCreateBuffer(apiDevice, _)).WillOnce(Return(apiBuffer)); - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, 4)).WillOnce(Return(&apiBufferData)); - buffer = device.CreateBuffer(&descriptor); FlushClient(); + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, 4)).WillOnce(Return(&apiBufferData)); EXPECT_CALL(api, BufferUnmap(apiBuffer)).Times(1); buffer.Unmap(); FlushClient(); @@ -627,8 +627,6 @@ api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) - .WillOnce(Return(&apiBufferData)); FlushClient(); 
FlushFutures(); @@ -647,8 +645,6 @@ uint32_t apiBufferData = 1234; EXPECT_CALL(api, DeviceCreateBuffer(apiDevice, _)).WillOnce(Return(apiBuffer)); - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, 4)).WillOnce(Return(&apiBufferData)); - buffer = device.CreateBuffer(&descriptor); FlushClient(); @@ -673,6 +669,7 @@ EXPECT_NE(nullptr, static_cast<const uint32_t*>(buffer.GetConstMappedRange(0, kBufferSize))); + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, 4)).WillOnce(Return(&apiBufferData)); EXPECT_CALL(api, BufferUnmap(apiBuffer)).Times(1); buffer.Unmap(); @@ -767,14 +764,15 @@ wgpu::MapMode mapMode = GetMapMode(); MapAsync(mapMode, 0, kBufferSize); - uint32_t bufferContent = 0; EXPECT_CALL(api, OnBufferMapAsync(apiBuffer, static_cast<WGPUMapMode>(mapMode), 0, kBufferSize, _)) .WillOnce(InvokeWithoutArgs([&] { api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - ExpectMappedRangeCall(kBufferSize, &bufferContent); + if (mapMode & wgpu::MapMode::Read) { + ExpectMappedRangeCall(); + } FlushClient(); ExpectWireCallbacksWhen([&](auto& mockCb) { @@ -803,14 +801,15 @@ MapAsync(mapMode, 0, kBufferSize); // Calls for the first successful map. - uint32_t bufferContent = 0; EXPECT_CALL(api, OnBufferMapAsync(apiBuffer, static_cast<WGPUMapMode>(mapMode), 0, kBufferSize, _)) .WillOnce(InvokeWithoutArgs([&] { api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - ExpectMappedRangeCall(kBufferSize, &bufferContent); + if (mapMode & wgpu::MapMode::Read) { + ExpectMappedRangeCall(); + } if (IsSpontaneous()) { // In spontaneous mode, the second map on the pending immediately calls the callback. 
@@ -849,7 +848,6 @@ // Test that GetMapState() returns map state as expected TEST_P(WireBufferMappingTests, GetMapState) { wgpu::MapMode mapMode = GetMapMode(); - uint32_t bufferContent = 31337; // Server-side success case { @@ -862,8 +860,10 @@ api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - ExpectMappedRangeCall(kBufferSize, &bufferContent); MapAsync(mapMode, 0, kBufferSize); + if (mapMode & wgpu::MapMode::Read) { + ExpectMappedRangeCall(); + } // Map state should become pending immediately after map async call. ASSERT_EQ(buffer.GetMapState(), wgpu::BufferMapState::Pending); @@ -887,6 +887,9 @@ ASSERT_EQ(buffer.GetMapState(), wgpu::BufferMapState::Mapped); } + if (mapMode & wgpu::MapMode::Write) { + ExpectMappedRangeCall(); + } EXPECT_CALL(api, BufferUnmap(apiBuffer)).Times(1); buffer.Unmap(); FlushClient(); @@ -933,14 +936,15 @@ wgpu::MapMode mapMode = GetMapMode(); MapAsync(mapMode, 0, kBufferSize); - uint32_t bufferContent = 0; EXPECT_CALL(api, OnBufferMapAsync(apiBuffer, static_cast<WGPUMapMode>(mapMode), 0, kBufferSize, _)) .WillOnce(InvokeWithoutArgs([&] { api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - ExpectMappedRangeCall(kBufferSize, &bufferContent); + if (mapMode & wgpu::MapMode::Read) { + ExpectMappedRangeCall(); + } FlushClient(); @@ -964,14 +968,15 @@ wgpu::MapMode mapMode = GetMapMode(); MapAsync(mapMode, 0, kBufferSize); - uint32_t bufferContent = 0; EXPECT_CALL(api, OnBufferMapAsync(apiBuffer, static_cast<WGPUMapMode>(mapMode), 0, kBufferSize, _)) .WillOnce(InvokeWithoutArgs([&] { api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, kEmptyOutputStringView); })); - ExpectMappedRangeCall(kBufferSize, &bufferContent); + if (mapMode & wgpu::MapMode::Read) { + ExpectMappedRangeCall(); + } FlushClient(); FlushFutures();
diff --git a/src/dawn/tests/unittests/wire/WireMemoryTransferServiceTests.cpp b/src/dawn/tests/unittests/wire/WireMemoryTransferServiceTests.cpp index 54f8392..17933dc 100644 --- a/src/dawn/tests/unittests/wire/WireMemoryTransferServiceTests.cpp +++ b/src/dawn/tests/unittests/wire/WireMemoryTransferServiceTests.cpp
@@ -138,10 +138,6 @@ // When the commands are flushed, the server should appropriately deserialize the handles. auto serverHandles = ExpectHandleDeserialization(true); - if (GetParam().mMappedAtCreation) { - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) - .WillOnce(Return(&mServerBufferContent)); - } FlushClient(); return std::make_tuple(apiBuffer, buffer, clientHandles, serverHandles); @@ -374,8 +370,6 @@ auto* clientHandle = std::get<MockClientWriteHandle*>(clientHandles); ASSERT_THAT(clientHandle, NotNull()); EXPECT_CALL(*clientHandle, GetData).WillOnce(Return(&mClientBufferContent)); - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) - .WillOnce(Return(&mClientBufferContent)); buffer.MapAsync(mode, 0, kBufferSize, wgpu::CallbackMode::AllowSpontaneous, mMapAsyncCb.Callback()); @@ -409,6 +403,8 @@ // The server should deserialize into its buffer when the client flushes. EXPECT_CALL(api, BufferUnmap(apiBuffer)).Times(1); + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) + .WillOnce(Return(&mServerBufferContent)); ExpectServerDeserializeData(true, serverHandles); FlushClient(); @@ -683,8 +679,7 @@ } case wgpu::MapMode::Write: { EXPECT_CALL(mMapAsyncCb, Call(wgpu::MapAsyncStatus::Success, _)).Times(1); - EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) - .WillOnce(Return(&mClientBufferContent)); + auto* clientHandle = std::get<MockClientWriteHandle*>(clientHandles); ASSERT_THAT(clientHandle, NotNull()); EXPECT_CALL(*clientHandle, GetData).WillOnce(Return(&mClientBufferContent)); @@ -699,6 +694,8 @@ buffer.Unmap(); // Mock that the server fails to deserialize into its buffer when the client flushes. 
+ EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) + .WillOnce(Return(&mServerBufferContent)); ExpectServerDeserializeData(false, serverHandles); FlushClient(false); break; @@ -791,6 +788,8 @@ buffer.Unmap(); // Mock that the server fails to deserialize into its buffer when the client flushes. + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)) + .WillOnce(Return(&mServerBufferContent)); ExpectServerDeserializeData(false, serverHandles); FlushClient(false);
diff --git a/src/dawn/tests/unittests/wire/WireSpecificCommandTests.cpp b/src/dawn/tests/unittests/wire/WireSpecificCommandTests.cpp new file mode 100644 index 0000000..3f8087b --- /dev/null +++ b/src/dawn/tests/unittests/wire/WireSpecificCommandTests.cpp
@@ -0,0 +1,139 @@ +// Copyright 2026 The Dawn & Tint Authors +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#include "dawn/common/StringViewUtils.h" +#include "dawn/tests/unittests/wire/WireTest.h" +#include "dawn/wire/ChunkedCommandSerializer.h" +#include "dawn/wire/Wire.h" +#include "dawn/wire/WireClient.h" +#include "dawn/wire/WireCmd_autogen.h" +#include "dawn/wire/WireServer.h" +#include "dawn/wire/client/Client.h" + +namespace dawn::wire { +namespace { + +using testing::_; +using testing::InvokeWithoutArgs; +using testing::Return; + +// Fixture that helps execute specific commands through the wire that may not be possible to trigger +// through usage of the dawn::wire::client. It is even more change-detecting than regular dawn::wire +// tests so we should use it only when there are no alternatives. +class WireSpecificCommandTests : public WireTest { + protected: + template <typename Cmd> + void AddSpecificServerCmd(const Cmd& cmd) { + CommandSerializer* c2s = GetC2SSerializer(); + ChunkedCommandSerializer serializer(c2s); + + serializer.SerializeCommand(cmd, *GetWireClient()->GetImplForTesting()); + } +}; + +// Regression test for https://issues.chromium.org/492139412 where a server receiving +// Device::Destroy wouldn't realize that the buffers got unmapped and would try to write into them. +// While it's not exactly possible to replicate the issue with WireTests since there is no +// dawn::native backend that will unmap buffers on destroy, we can check that the ordering of +// commands in the server is such that it will check that the buffer is mapped before writing into +// it. +TEST_F(WireSpecificCommandTests, UpdateMappedDataAfterDeviceDestroy_MappedAtCreation) { + // Create a mapped buffer.
+ wgpu::BufferDescriptor descriptor = {}; + descriptor.size = 4; + descriptor.usage = wgpu::BufferUsage::CopySrc; + descriptor.mappedAtCreation = true; + wgpu::Buffer buffer = device.CreateBuffer(&descriptor); + + WGPUBuffer apiBuffer = api.GetNewBuffer(); + EXPECT_CALL(api, DeviceCreateBuffer(apiDevice, _)) + .WillOnce(Return(apiBuffer)) + .RetiresOnSaturation(); + FlushClient(); + + // Force a device destroy without giving the wire::client a chance to unmap client-side buffers. + DeviceDestroyCmd cmd; + cmd.self = device.Get(); + AddSpecificServerCmd(cmd); + + EXPECT_CALL(api, DeviceDestroy(apiDevice)).Times(1); + FlushClient(); + + // A call to unmap will get a nullptr mapped range and should not write to it! (if it did, we'd + // see a crash here since it would write to nullptr). + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, 4)).WillOnce(Return(nullptr)); + EXPECT_CALL(api, BufferUnmap(apiBuffer)).Times(1); + buffer.Unmap(); + FlushClient(); +} + +// The same test but mapping at a non-zero offset, to check that the offset can't be used to bypass +// the server's null check on the mapped range. This was a bug found during review of the fix. +TEST_F(WireSpecificCommandTests, UpdateMappedDataAfterDeviceDestroy_MapWriteOffsetNonZero) { + // Create a MapWrite buffer (it is mapped below with MapAsync).
+ wgpu::BufferDescriptor descriptor = {}; + descriptor.size = 8; + descriptor.usage = wgpu::BufferUsage::MapWrite; + wgpu::Buffer buffer = device.CreateBuffer(&descriptor); + + WGPUBuffer apiBuffer = api.GetNewBuffer(); + EXPECT_CALL(api, DeviceCreateBuffer(apiDevice, _)) + .WillOnce(Return(apiBuffer)) + .RetiresOnSaturation(); + FlushClient(); + + // Map bytes [4, 8) of the buffer for writing. + buffer.MapAsync(wgpu::MapMode::Write, 4, 4, wgpu::CallbackMode::AllowProcessEvents, + [](wgpu::MapAsyncStatus status, wgpu::StringView) {}); + EXPECT_CALL(api, OnBufferMapAsync(apiBuffer, WGPUMapMode_Write, 4, 4, _)) + .WillOnce(InvokeWithoutArgs([&] { + api.CallBufferMapAsyncCallback(apiBuffer, WGPUMapAsyncStatus_Success, + kEmptyOutputStringView); + })); + + FlushClient(); + FlushServer(); + instance.ProcessEvents(); + + // Force a device destroy without giving the wire::client a chance to unmap client-side buffers. + DeviceDestroyCmd cmd; + cmd.self = device.Get(); + AddSpecificServerCmd(cmd); + + EXPECT_CALL(api, DeviceDestroy(apiDevice)).Times(1); + FlushClient(); + + // A call to unmap will get a nullptr mapped range and should not write to it! (if it did, we'd + // see a crash here since it would write to nullptr). + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 4, 4)).WillOnce(Return(nullptr)); + EXPECT_CALL(api, BufferUnmap(apiBuffer)).Times(1); + buffer.Unmap(); + FlushClient(); +} + +} // anonymous namespace +} // namespace dawn::wire
diff --git a/src/dawn/tests/unittests/wire/WireTest.cpp b/src/dawn/tests/unittests/wire/WireTest.cpp index f22285e..7d26e11 100644 --- a/src/dawn/tests/unittests/wire/WireTest.cpp +++ b/src/dawn/tests/unittests/wire/WireTest.cpp
@@ -246,6 +246,14 @@ return mWireClient.get(); } +wire::CommandSerializer* WireTest::GetC2SSerializer() { + return mC2sBuf.get(); +} + +wire::CommandSerializer* WireTest::GetS2CSerializer() { + return mS2cBuf.get(); +} + size_t WireTest::GetC2SMaxAllocationSize() { return mC2sBuf->GetMaximumAllocationSize(); }
diff --git a/src/dawn/tests/unittests/wire/WireTest.h b/src/dawn/tests/unittests/wire/WireTest.h index 41e8917..6688a52 100644 --- a/src/dawn/tests/unittests/wire/WireTest.h +++ b/src/dawn/tests/unittests/wire/WireTest.h
@@ -127,6 +127,7 @@ } while (0) namespace wire { +class CommandSerializer; class WireClient; class WireServer; namespace client { @@ -176,8 +177,11 @@ wgpu::Queue queue; WGPUQueue apiQueue; - dawn::wire::WireServer* GetWireServer(); - dawn::wire::WireClient* GetWireClient(); + wire::WireServer* GetWireServer(); + wire::WireClient* GetWireClient(); + + wire::CommandSerializer* GetC2SSerializer(); + wire::CommandSerializer* GetS2CSerializer(); size_t GetC2SMaxAllocationSize(); @@ -185,17 +189,17 @@ void DeleteClient(); private: - virtual dawn::wire::client::MemoryTransferService* GetClientMemoryTransferService(); - virtual dawn::wire::server::MemoryTransferService* GetServerMemoryTransferService(); + virtual wire::client::MemoryTransferService* GetClientMemoryTransferService(); + virtual wire::server::MemoryTransferService* GetServerMemoryTransferService(); // Devices created on the server MUST call Device.Destroy at least once. This map is used to // ensure that this invariant holds true for any devices returned. absl::flat_hash_map<WGPUDevice, bool> mDeviceDestroyed; - std::unique_ptr<dawn::wire::WireServer> mWireServer; - std::unique_ptr<dawn::wire::WireClient> mWireClient; - std::unique_ptr<dawn::utils::TerribleCommandBuffer> mS2cBuf; - std::unique_ptr<dawn::utils::TerribleCommandBuffer> mC2sBuf; + std::unique_ptr<wire::WireServer> mWireServer; + std::unique_ptr<wire::WireClient> mWireClient; + std::unique_ptr<utils::TerribleCommandBuffer> mS2cBuf; + std::unique_ptr<utils::TerribleCommandBuffer> mC2sBuf; }; } // namespace dawn
diff --git a/src/dawn/wire/WireClient.cpp b/src/dawn/wire/WireClient.cpp index 98b260a..8244522 100644 --- a/src/dawn/wire/WireClient.cpp +++ b/src/dawn/wire/WireClient.cpp
@@ -95,6 +95,10 @@ return wireDevice->GetWireHandle(mImpl.get()); } +client::Client* WireClient::GetImplForTesting() { + return mImpl.get(); +} + namespace client { MemoryTransferService::MemoryTransferService() = default;
diff --git a/src/dawn/wire/server/ObjectStorage.h b/src/dawn/wire/server/ObjectStorage.h index c9ecb25..b6ee390 100644 --- a/src/dawn/wire/server/ObjectStorage.h +++ b/src/dawn/wire/server/ObjectStorage.h
@@ -63,12 +63,9 @@ template <typename T> struct ObjectData : public ObjectDataBase<T> {}; -enum class BufferMapWriteState { Unmapped, Mapped, MapError }; - struct BufferMapState { std::unique_ptr<MemoryTransferService::ReadHandle> readHandle = nullptr; std::unique_ptr<MemoryTransferService::WriteHandle> writeHandle = nullptr; - BufferMapWriteState writeState = BufferMapWriteState::Unmapped; }; template <>
diff --git a/src/dawn/wire/server/ServerBuffer.cpp b/src/dawn/wire/server/ServerBuffer.cpp index 059243b..3b927e7 100644 --- a/src/dawn/wire/server/ServerBuffer.cpp +++ b/src/dawn/wire/server/ServerBuffer.cpp
@@ -32,6 +32,7 @@ #include <limits> #include <memory> +#include <span> #include "dawn/common/Assert.h" #include "dawn/common/StringViewUtils.h" @@ -63,7 +64,6 @@ // don't assert it's non-null mapState->writeHandle = nullptr; } - mapState->writeState = BufferMapWriteState::Unmapped; }); return WireResult::Success; @@ -78,7 +78,6 @@ buffer->mapState.Use([](auto mapState) { mapState->readHandle = nullptr; mapState->writeHandle = nullptr; - mapState->writeState = BufferMapWriteState::Unmapped; }); return WireResult::Success; @@ -163,10 +162,8 @@ if (buffer->handle == nullptr) { DAWN_ASSERT(descriptor->mappedAtCreation); // A null buffer indicates that mapping-at-creation failed inside createBuffer. - // - Unmark the buffer as allocated so we will skip freeing it. + // Unmark the buffer as allocated so we will skip freeing it. buffer->state = AllocationState::Reserved; - // - Remember the buffer is an error so we will skip subsequent mapping operations. - mapState->writeState = BufferMapWriteState::MapError; return WireResult::Success; } @@ -179,23 +176,6 @@ } DAWN_ASSERT(writeHandle != nullptr); mapState->writeHandle.reset(writeHandle); - writeHandle->SetDataLength(descriptor->size); - - if (descriptor->mappedAtCreation) { - void* mapping = mProcs->bufferGetMappedRange(buffer->handle, 0, descriptor->size); - if (mapping == nullptr) { - DAWN_ASSERT(descriptor->size % 4 != 0); - // GetMappedRange can still fail if the buffer's size isn't aligned. - // - Remember the buffer is an error so we will skip subsequent mapping - // operations. 
- mapState->writeState = BufferMapWriteState::MapError; - return WireResult::Success; - } - DAWN_ASSERT(mapping != nullptr); - writeHandle->SetTarget(mapping); - - mapState->writeState = BufferMapWriteState::Mapped; - } } if (isReadMode) { @@ -207,7 +187,6 @@ return WireResult::FatalError; } DAWN_ASSERT(readHandle != nullptr); - mapState->readHandle.reset(readHandle); } @@ -226,25 +205,40 @@ } return buffer->mapState.Use([&](auto mapState) { - switch (mapState->writeState) { - case BufferMapWriteState::Unmapped: - return WireResult::FatalError; - case BufferMapWriteState::MapError: - // The buffer is mapped but there was an error allocating mapped data. - // Do not perform the memcpy. - return WireResult::Success; - case BufferMapWriteState::Mapped: - break; + uint8_t* mappedData = + static_cast<uint8_t*>(mProcs->bufferGetMappedRange(buffer->handle, offset, size)); + + // There are a few valid reasons why getting the mapped range would fail here: + // - The buffer was implicitly unmapped because of a device.Destroy() call. + // - The buffer was an error buffer created just to replace an OOM mappedAtCreation buffer. + // Unfortunately validating exactly that the failure is due to a valid reason and not + // another is difficult, so we return WireResult::Success even for misuses of the wire + // protocol (like a size being larger than the buffer's size, etc). + if (mappedData == nullptr) { + return WireResult::Success; } + // TODO(https://issues.chromium.org/492456046): We would like to map only the `offset` and + // `size` here but the Chromium implementation of DeserializeDataUpdate uses `offset` to + // offset both the target data and it's shmem pointer. So the pointer passed in SetTarget + // must be for the start of the buffer. Fix this somehow when spanifying the interfaces but + // for now we need to duplicate the overflow check that's done in GetMappedRange. 
+ mappedData -= offset; + + // Note that offset + size was checked to not overflow in GetMappedRange above. + std::span<uint8_t> mappedRange = {mappedData, static_cast<size_t>(offset + size)}; + + // However it is easy to check for misuses of the wire protocol to UpdateMappedData without + // a WriteHandle. if (!mapState->writeHandle) { - // This check is performed after the check for the MapError state. It is permissible - // to Unmap and attempt to update mapped data of an error buffer. return WireResult::FatalError; } // Deserialize the flush info and flush updated data from the handle into the target - // of the handle. The target is set via WriteHandle::SetTarget. + // of the handle. The target is set via WriteHandle::SetTarget/SetDataLength. + mapState->writeHandle->SetDataLength(mappedRange.size()); + mapState->writeHandle->SetTarget(mappedRange.data()); + if (!mapState->writeHandle->DeserializeDataUpdate( writeDataUpdateInfo, static_cast<size_t>(writeDataUpdateInfoLength), static_cast<size_t>(offset), static_cast<size_t>(size))) { @@ -301,18 +295,6 @@ break; } case WGPUMapMode_Write: { - buffer->mapState.Use([&](auto mapState) { - // The in-flight map request returned successfully. - mapState->writeState = BufferMapWriteState::Mapped; - // Set the target of the WriteHandle to the mapped buffer data. - // Note that writeHandle's target always refers to the buffer base address, but we - // call getMappedRange exactly with the range of data that is potentially modified - // (i.e. we don't want getMappedRange(0, wholeBufferSize) if only a subset of the - // buffer is actually mapped) in case the implementation does some range tracking. - mapState->writeHandle->SetTarget(static_cast<uint8_t*>(mProcs->bufferGetMappedRange( - data->bufferObj, data->offset, data->size)) - - data->offset); - }); SerializeCommand(cmd); break; }