tools/remote-compile: Handle socket disconnects If the socket disconnected mid-communication, the server could spin, waiting for new data. Actually handle recv() errors, preventing the server spinning itself to death. Also fix code style to be more tint-like (snake_case variables, PascalCase functions) Change-Id: I9fcbfde303a8624e7e1ff87abd33581589f4da42 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105142 Auto-Submit: Ben Clayton <bclayton@google.com> Reviewed-by: Corentin Wallez <cwallez@chromium.org> Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/tools/src/cmd/remote-compile/main.cc b/tools/src/cmd/remote-compile/main.cc index 20fb911..53a7deb 100644 --- a/tools/src/cmd/remote-compile/main.cc +++ b/tools/src/cmd/remote-compile/main.cc
@@ -145,6 +145,8 @@ } size -= n; buf += n; + } else { + error = "Socket::Read() failed"; } } return error.empty(); @@ -238,13 +240,15 @@ MESSAGE& m) { Message::Type ty; s >> ty; - if (ty == m.type) { - m.Serialize([&s](auto& value) { s >> value; }); - } else { - std::stringstream ss; - ss << "expected message type " << static_cast<int>(m.type) << ", got " - << static_cast<int>(ty); - s.error = ss.str(); + if (s.error.empty()) { + if (ty == m.type) { + m.Serialize([&s](auto& value) { s >> value; }); + } else { + std::stringstream ss; + ss << "expected message type " << static_cast<int>(m.type) << ", got " + << static_cast<int>(ty); + s.error = ss.str(); + } } return s; } @@ -291,8 +295,7 @@ continue; } - // xcrun flags are ignored so this executable can be used as a replacement - // for xcrun. + // xcrun flags are ignored so this executable can be used as a replacement for xcrun. if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) { i++; continue; @@ -357,7 +360,7 @@ ConnectionRequest req; stream >> req; if (!stream.error.empty()) { - printf("%s\n", stream.error.c_str()); + DEBUG("%s", stream.error.c_str()); return; } ConnectionResponse resp; @@ -374,7 +377,7 @@ CompileRequest req; stream >> req; if (!stream.error.empty()) { - printf("%s\n", stream.error.c_str()); + DEBUG("%s\n", stream.error.c_str()); return; } #ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API
diff --git a/tools/src/cmd/remote-compile/rwmutex.h b/tools/src/cmd/remote-compile/rwmutex.h index 6970ff3..0a6381d 100644 --- a/tools/src/cmd/remote-compile/rwmutex.h +++ b/tools/src/cmd/remote-compile/rwmutex.h
@@ -29,56 +29,56 @@ public: inline RWMutex() = default; - /// lockReader() locks the mutex for reading. + /// LockReader() locks the mutex for reading. /// Multiple read locks can be held while there are no writer locks. - inline void lockReader(); + inline void LockReader(); - /// unlockReader() unlocks the mutex for reading. - inline void unlockReader(); + /// UnlockReader() unlocks the mutex for reading. + inline void UnlockReader(); - /// lockWriter() locks the mutex for writing. - /// If the lock is already locked for reading or writing, lockWriter blocks + /// LockWriter() locks the mutex for writing. + /// If the lock is already locked for reading or writing, LockWriter blocks /// until the lock is available. - inline void lockWriter(); + inline void LockWriter(); - /// unlockWriter() unlocks the mutex for writing. - inline void unlockWriter(); + /// UnlockWriter() unlocks the mutex for writing. + inline void UnlockWriter(); private: RWMutex(const RWMutex&) = delete; RWMutex& operator=(const RWMutex&) = delete; - int readLocks = 0; - int pendingWriteLocks = 0; + int read_locks = 0; + int pending_write_locks = 0; std::mutex mutex; std::condition_variable cv; }; -void RWMutex::lockReader() { +void RWMutex::LockReader() { std::unique_lock<std::mutex> lock(mutex); - readLocks++; + read_locks++; } -void RWMutex::unlockReader() { +void RWMutex::UnlockReader() { std::unique_lock<std::mutex> lock(mutex); - readLocks--; - if (readLocks == 0 && pendingWriteLocks > 0) { + read_locks--; + if (read_locks == 0 && pending_write_locks > 0) { cv.notify_one(); } } -void RWMutex::lockWriter() { +void RWMutex::LockWriter() { std::unique_lock<std::mutex> lock(mutex); - if (readLocks > 0) { - pendingWriteLocks++; - cv.wait(lock, [&] { return readLocks == 0; }); - pendingWriteLocks--; + if (read_locks > 0) { + pending_write_locks++; + cv.wait(lock, [&] { return read_locks == 0; }); + pending_write_locks--; } lock.release(); // Keep lock held } -void RWMutex::unlockWriter() { - if (pendingWriteLocks > 0) { +void RWMutex::UnlockWriter() { + if (pending_write_locks > 0) { cv.notify_one(); } mutex.unlock(); @@ -115,12 +115,12 @@ }; RLock::RLock(RWMutex& mutex) : m(&mutex) { - m->lockReader(); + m->LockReader(); } RLock::~RLock() { if (m != nullptr) { - m->unlockReader(); + m->UnlockReader(); } } @@ -167,12 +167,12 @@ }; WLock::WLock(RWMutex& mutex) : m(&mutex) { - m->lockWriter(); + m->LockWriter(); } WLock::~WLock() { if (m != nullptr) { - m->unlockWriter(); + m->UnlockWriter(); } }
diff --git a/tools/src/cmd/remote-compile/socket.cc b/tools/src/cmd/remote-compile/socket.cc index a16940a..91e87a0 100644 --- a/tools/src/cmd/remote-compile/socket.cc +++ b/tools/src/cmd/remote-compile/socket.cc
@@ -32,7 +32,7 @@ #if defined(_WIN32) #include <atomic> namespace { -std::atomic<int> wsaInitCount = {0}; +std::atomic<int> wsa_init_count = {0}; } // anonymous namespace #else #include <fcntl.h> @@ -43,24 +43,24 @@ namespace { constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1); -void init() { +void Init() { #if defined(_WIN32) - if (wsaInitCount++ == 0) { - WSADATA winsockData; - (void)WSAStartup(MAKEWORD(2, 2), &winsockData); + if (wsa_init_count++ == 0) { + WSADATA winsock_data; + (void)WSAStartup(MAKEWORD(2, 2), &winsock_data); } #endif } -void term() { +void Term() { #if defined(_WIN32) - if (--wsaInitCount == 0) { + if (--wsa_init_count == 0) { WSACleanup(); } #endif } -bool setBlocking(SOCKET s, bool blocking) { +bool SetBlocking(SOCKET s, bool blocking) { #if defined(_WIN32) u_long mode = blocking ? 0 : 1; return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR; @@ -74,7 +74,7 @@ #endif } -bool errored(SOCKET s) { +bool Errored(SOCKET s) { if (s == InvalidSocket) { return true; } @@ -87,7 +87,7 @@ class Impl : public Socket { public: static std::shared_ptr<Impl> create(const char* address, const char* port) { - init(); + Init(); addrinfo hints = {}; hints.ai_family = AF_INET; @@ -106,12 +106,12 @@ if (info) { auto socket = ::socket(info->ai_family, info->ai_socktype, info->ai_protocol); auto out = std::make_shared<Impl>(info, socket); - out->setOptions(); + out->SetOptions(); return out; } freeaddrinfo(info); - term(); + Term(); return nullptr; } @@ -121,16 +121,16 @@ ~Impl() { freeaddrinfo(info); Close(); - term(); + Term(); } template <typename FUNCTION> - void lock(FUNCTION&& f) { + void Lock(FUNCTION&& f) { RLock l(mutex); f(s, info); } - void setOptions() { + void SetOptions() { RLock l(mutex); if (s == InvalidSocket) { return; @@ -157,7 +157,7 @@ bool IsOpen() override { { RLock l(mutex); - if ((s != InvalidSocket) && !errored(s)) { + if ((s != InvalidSocket) && !Errored(s)) { return true; } } @@ -188,12 +188,20 @@ } size_t Read(void* buffer, size_t bytes) override { - RLock lock(mutex); - if (s == InvalidSocket) { - return 0; + { + RLock lock(mutex); + if (s == InvalidSocket) { + return 0; + } + size_t len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0); + if (len > 0) { + return len; + } } - auto len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0); - return (len < 0) ? 0 : len; + // Socket closed or errored + WLock lock(mutex); + s = InvalidSocket; + return 0; } bool Write(const void* buffer, size_t bytes) override { @@ -209,11 +217,13 @@ std::shared_ptr<Socket> Accept() override { std::shared_ptr<Impl> out; - lock([&](SOCKET socket, const addrinfo*) { + Lock([&](SOCKET socket, const addrinfo*) { if (socket != InvalidSocket) { - init(); - out = std::make_shared<Impl>(::accept(socket, 0, 0)); - out->setOptions(); + Init(); + if (auto s = ::accept(socket, 0, 0); s >= 0) { + out = std::make_shared<Impl>(s); + out->SetOptions(); + } } }); return out; @@ -232,7 +242,7 @@ if (!impl) { return nullptr; } - impl->lock([&](SOCKET socket, const addrinfo* info) { + impl->Lock([&](SOCKET socket, const addrinfo* info) { if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) { impl.reset(); return; @@ -248,46 +258,46 @@ std::shared_ptr<Socket> Socket::Connect(const char* address, const char* port, - uint32_t timeoutMillis) { + uint32_t timeout_ms) { auto impl = Impl::create(address, port); if (!impl) { return nullptr; } std::shared_ptr<Socket> out; - impl->lock([&](SOCKET socket, const addrinfo* info) { + impl->Lock([&](SOCKET socket, const addrinfo* info) { if (socket == InvalidSocket) { return; } - if (timeoutMillis == 0) { + if (timeout_ms == 0) { if (::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) == 0) { out = impl; } return; } - if (!setBlocking(socket, false)) { + if (!SetBlocking(socket, false)) { return; } auto res = ::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)); if (res == 0) { - if (setBlocking(socket, true)) { + if (SetBlocking(socket, true)) { out = impl; } } else { - const auto microseconds = timeoutMillis * 1000; + const auto timeout_us = timeout_ms * 1000; fd_set fdset; FD_ZERO(&fdset); FD_SET(socket, &fdset); timeval tv; - tv.tv_sec = microseconds / 1000000; - tv.tv_usec = microseconds - static_cast<uint32_t>(tv.tv_sec * 1000000); + tv.tv_sec = timeout_us / 1000000; + tv.tv_usec = timeout_us - static_cast<uint32_t>(tv.tv_sec * 1000000); res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv); - if (res > 0 && !errored(socket) && setBlocking(socket, true)) { + if (res > 0 && !Errored(socket) && SetBlocking(socket, true)) { out = impl; } }
diff --git a/tools/src/cmd/remote-compile/socket.h b/tools/src/cmd/remote-compile/socket.h index 514ebac..59115ea 100644 --- a/tools/src/cmd/remote-compile/socket.h +++ b/tools/src/cmd/remote-compile/socket.h
@@ -24,29 +24,27 @@ /// Connects to the given TCP address and port. /// @param address the target socket address /// @param port the target socket port - /// @param timeoutMillis the timeout for the connection attempt. - /// If timeoutMillis is non-zero and no connection was made before - /// timeoutMillis milliseconds, then nullptr is returned. + /// @param timeout_ms the timeout for the connection attempt. + /// If timeout_ms is non-zero and no connection was made before timeout_ms milliseconds, + /// then nullptr is returned. /// @returns the connected Socket, or nullptr on failure static std::shared_ptr<Socket> Connect(const char* address, const char* port, - uint32_t timeoutMillis); + uint32_t timeout_ms); /// Begins listening for connections on the given TCP address and port. /// Call Accept() on the returned Socket to block and wait for a connection. - /// @param address the socket address to listen on. Use "localhost" for - /// connections from only this machine, or an empty string to allow - /// connections from any incoming address. + /// @param address the socket address to listen on. Use "localhost" for connections from only + /// this machine, or an empty string to allow connections from any incoming address. /// @param port the socket port to listen on /// @returns the Socket that listens for connections static std::shared_ptr<Socket> Listen(const char* address, const char* port); - /// Attempts to read at most `n` bytes into buffer, returning the actual - /// number of bytes read. + /// Attempts to read at most `n` bytes into buffer, returning the actual number of bytes read. /// read() will block until the socket is closed or at least one byte is read. /// @param buffer the output buffer. Must be at least `n` bytes in size. /// @param n the maximum number of bytes to read - /// @return the number of bytes read, or 0 if the socket was closed + /// @return the number of bytes read, or 0 if the socket was closed or errored virtual size_t Read(void* buffer, size_t n) = 0; /// Writes `n` bytes from buffer into the socket. @@ -62,8 +60,7 @@ /// Closes the socket. virtual void Close() = 0; - /// Blocks for a connection to be made to the listening port, or for the - /// Socket to be closed. + /// Blocks for a connection to be made to the listening port, or for the Socket to be closed. /// @returns a pointer to the next established incoming connection virtual std::shared_ptr<Socket> Accept() = 0; };