tools/remote-compile: Handle socket disconnects
If the socket disconnected mid-communication, the server could spin, waiting for new data.
Actually handle recv() errors, preventing the server spinning itself to death.
Also fix code style to be more tint-like (snake_case variables, PascalCase functions)
Change-Id: I9fcbfde303a8624e7e1ff87abd33581589f4da42
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105142
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/tools/src/cmd/remote-compile/main.cc b/tools/src/cmd/remote-compile/main.cc
index 20fb911..53a7deb 100644
--- a/tools/src/cmd/remote-compile/main.cc
+++ b/tools/src/cmd/remote-compile/main.cc
@@ -145,6 +145,8 @@
}
size -= n;
buf += n;
+ } else {
+ error = "Socket::Read() failed";
}
}
return error.empty();
@@ -238,13 +240,15 @@
MESSAGE& m) {
Message::Type ty;
s >> ty;
- if (ty == m.type) {
- m.Serialize([&s](auto& value) { s >> value; });
- } else {
- std::stringstream ss;
- ss << "expected message type " << static_cast<int>(m.type) << ", got "
- << static_cast<int>(ty);
- s.error = ss.str();
+ if (s.error.empty()) {
+ if (ty == m.type) {
+ m.Serialize([&s](auto& value) { s >> value; });
+ } else {
+ std::stringstream ss;
+ ss << "expected message type " << static_cast<int>(m.type) << ", got "
+ << static_cast<int>(ty);
+ s.error = ss.str();
+ }
}
return s;
}
@@ -291,8 +295,7 @@
continue;
}
- // xcrun flags are ignored so this executable can be used as a replacement
- // for xcrun.
+ // xcrun flags are ignored so this executable can be used as a replacement for xcrun.
if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) {
i++;
continue;
@@ -357,7 +360,7 @@
ConnectionRequest req;
stream >> req;
if (!stream.error.empty()) {
- printf("%s\n", stream.error.c_str());
+ DEBUG("%s", stream.error.c_str());
return;
}
ConnectionResponse resp;
@@ -374,7 +377,7 @@
CompileRequest req;
stream >> req;
if (!stream.error.empty()) {
- printf("%s\n", stream.error.c_str());
+ DEBUG("%s\n", stream.error.c_str());
return;
}
#ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API
diff --git a/tools/src/cmd/remote-compile/rwmutex.h b/tools/src/cmd/remote-compile/rwmutex.h
index 6970ff3..0a6381d 100644
--- a/tools/src/cmd/remote-compile/rwmutex.h
+++ b/tools/src/cmd/remote-compile/rwmutex.h
@@ -29,56 +29,56 @@
public:
inline RWMutex() = default;
- /// lockReader() locks the mutex for reading.
+ /// LockReader() locks the mutex for reading.
/// Multiple read locks can be held while there are no writer locks.
- inline void lockReader();
+ inline void LockReader();
- /// unlockReader() unlocks the mutex for reading.
- inline void unlockReader();
+ /// UnlockReader() unlocks the mutex for reading.
+ inline void UnlockReader();
- /// lockWriter() locks the mutex for writing.
- /// If the lock is already locked for reading or writing, lockWriter blocks
+ /// LockWriter() locks the mutex for writing.
+ /// If the lock is already locked for reading or writing, LockWriter blocks
/// until the lock is available.
- inline void lockWriter();
+ inline void LockWriter();
- /// unlockWriter() unlocks the mutex for writing.
- inline void unlockWriter();
+ /// UnlockWriter() unlocks the mutex for writing.
+ inline void UnlockWriter();
private:
RWMutex(const RWMutex&) = delete;
RWMutex& operator=(const RWMutex&) = delete;
- int readLocks = 0;
- int pendingWriteLocks = 0;
+ int read_locks = 0;
+ int pending_write_locks = 0;
std::mutex mutex;
std::condition_variable cv;
};
-void RWMutex::lockReader() {
+void RWMutex::LockReader() {
std::unique_lock<std::mutex> lock(mutex);
- readLocks++;
+ read_locks++;
}
-void RWMutex::unlockReader() {
+void RWMutex::UnlockReader() {
std::unique_lock<std::mutex> lock(mutex);
- readLocks--;
- if (readLocks == 0 && pendingWriteLocks > 0) {
+ read_locks--;
+ if (read_locks == 0 && pending_write_locks > 0) {
cv.notify_one();
}
}
-void RWMutex::lockWriter() {
+void RWMutex::LockWriter() {
std::unique_lock<std::mutex> lock(mutex);
- if (readLocks > 0) {
- pendingWriteLocks++;
- cv.wait(lock, [&] { return readLocks == 0; });
- pendingWriteLocks--;
+ if (read_locks > 0) {
+ pending_write_locks++;
+ cv.wait(lock, [&] { return read_locks == 0; });
+ pending_write_locks--;
}
lock.release(); // Keep lock held
}
-void RWMutex::unlockWriter() {
- if (pendingWriteLocks > 0) {
+void RWMutex::UnlockWriter() {
+ if (pending_write_locks > 0) {
cv.notify_one();
}
mutex.unlock();
@@ -115,12 +115,12 @@
};
RLock::RLock(RWMutex& mutex) : m(&mutex) {
- m->lockReader();
+ m->LockReader();
}
RLock::~RLock() {
if (m != nullptr) {
- m->unlockReader();
+ m->UnlockReader();
}
}
@@ -167,12 +167,12 @@
};
WLock::WLock(RWMutex& mutex) : m(&mutex) {
- m->lockWriter();
+ m->LockWriter();
}
WLock::~WLock() {
if (m != nullptr) {
- m->unlockWriter();
+ m->UnlockWriter();
}
}
diff --git a/tools/src/cmd/remote-compile/socket.cc b/tools/src/cmd/remote-compile/socket.cc
index a16940a..91e87a0 100644
--- a/tools/src/cmd/remote-compile/socket.cc
+++ b/tools/src/cmd/remote-compile/socket.cc
@@ -32,7 +32,7 @@
#if defined(_WIN32)
#include <atomic>
namespace {
-std::atomic<int> wsaInitCount = {0};
+std::atomic<int> wsa_init_count = {0};
} // anonymous namespace
#else
#include <fcntl.h>
@@ -43,24 +43,24 @@
namespace {
constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
-void init() {
+void Init() {
#if defined(_WIN32)
- if (wsaInitCount++ == 0) {
- WSADATA winsockData;
- (void)WSAStartup(MAKEWORD(2, 2), &winsockData);
+ if (wsa_init_count++ == 0) {
+ WSADATA winsock_data;
+ (void)WSAStartup(MAKEWORD(2, 2), &winsock_data);
}
#endif
}
-void term() {
+void Term() {
#if defined(_WIN32)
- if (--wsaInitCount == 0) {
+ if (--wsa_init_count == 0) {
WSACleanup();
}
#endif
}
-bool setBlocking(SOCKET s, bool blocking) {
+bool SetBlocking(SOCKET s, bool blocking) {
#if defined(_WIN32)
u_long mode = blocking ? 0 : 1;
return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
@@ -74,7 +74,7 @@
#endif
}
-bool errored(SOCKET s) {
+bool Errored(SOCKET s) {
if (s == InvalidSocket) {
return true;
}
@@ -87,7 +87,7 @@
class Impl : public Socket {
public:
static std::shared_ptr<Impl> create(const char* address, const char* port) {
- init();
+ Init();
addrinfo hints = {};
hints.ai_family = AF_INET;
@@ -106,12 +106,12 @@
if (info) {
auto socket = ::socket(info->ai_family, info->ai_socktype, info->ai_protocol);
auto out = std::make_shared<Impl>(info, socket);
- out->setOptions();
+ out->SetOptions();
return out;
}
freeaddrinfo(info);
- term();
+ Term();
return nullptr;
}
@@ -121,16 +121,16 @@
~Impl() {
freeaddrinfo(info);
Close();
- term();
+ Term();
}
template <typename FUNCTION>
- void lock(FUNCTION&& f) {
+ void Lock(FUNCTION&& f) {
RLock l(mutex);
f(s, info);
}
- void setOptions() {
+ void SetOptions() {
RLock l(mutex);
if (s == InvalidSocket) {
return;
@@ -157,7 +157,7 @@
bool IsOpen() override {
{
RLock l(mutex);
- if ((s != InvalidSocket) && !errored(s)) {
+ if ((s != InvalidSocket) && !Errored(s)) {
return true;
}
}
@@ -188,12 +188,20 @@
}
size_t Read(void* buffer, size_t bytes) override {
- RLock lock(mutex);
- if (s == InvalidSocket) {
- return 0;
+ {
+ RLock lock(mutex);
+ if (s == InvalidSocket) {
+ return 0;
+ }
+ size_t len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0);
+ if (len > 0) {
+ return len;
+ }
}
- auto len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0);
- return (len < 0) ? 0 : len;
+ // Socket closed or errored
+ WLock lock(mutex);
+ s = InvalidSocket;
+ return 0;
}
bool Write(const void* buffer, size_t bytes) override {
@@ -209,11 +217,13 @@
std::shared_ptr<Socket> Accept() override {
std::shared_ptr<Impl> out;
- lock([&](SOCKET socket, const addrinfo*) {
+ Lock([&](SOCKET socket, const addrinfo*) {
if (socket != InvalidSocket) {
- init();
- out = std::make_shared<Impl>(::accept(socket, 0, 0));
- out->setOptions();
+ Init();
+ if (auto s = ::accept(socket, 0, 0); s >= 0) {
+ out = std::make_shared<Impl>(s);
+ out->SetOptions();
+ }
}
});
return out;
@@ -232,7 +242,7 @@
if (!impl) {
return nullptr;
}
- impl->lock([&](SOCKET socket, const addrinfo* info) {
+ impl->Lock([&](SOCKET socket, const addrinfo* info) {
if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) {
impl.reset();
return;
@@ -248,46 +258,46 @@
std::shared_ptr<Socket> Socket::Connect(const char* address,
const char* port,
- uint32_t timeoutMillis) {
+ uint32_t timeout_ms) {
auto impl = Impl::create(address, port);
if (!impl) {
return nullptr;
}
std::shared_ptr<Socket> out;
- impl->lock([&](SOCKET socket, const addrinfo* info) {
+ impl->Lock([&](SOCKET socket, const addrinfo* info) {
if (socket == InvalidSocket) {
return;
}
- if (timeoutMillis == 0) {
+ if (timeout_ms == 0) {
if (::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) == 0) {
out = impl;
}
return;
}
- if (!setBlocking(socket, false)) {
+ if (!SetBlocking(socket, false)) {
return;
}
auto res = ::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen));
if (res == 0) {
- if (setBlocking(socket, true)) {
+ if (SetBlocking(socket, true)) {
out = impl;
}
} else {
- const auto microseconds = timeoutMillis * 1000;
+ const auto timeout_us = timeout_ms * 1000;
fd_set fdset;
FD_ZERO(&fdset);
FD_SET(socket, &fdset);
timeval tv;
- tv.tv_sec = microseconds / 1000000;
- tv.tv_usec = microseconds - static_cast<uint32_t>(tv.tv_sec * 1000000);
+ tv.tv_sec = timeout_us / 1000000;
+ tv.tv_usec = timeout_us - static_cast<uint32_t>(tv.tv_sec * 1000000);
res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
- if (res > 0 && !errored(socket) && setBlocking(socket, true)) {
+ if (res > 0 && !Errored(socket) && SetBlocking(socket, true)) {
out = impl;
}
}
diff --git a/tools/src/cmd/remote-compile/socket.h b/tools/src/cmd/remote-compile/socket.h
index 514ebac..59115ea 100644
--- a/tools/src/cmd/remote-compile/socket.h
+++ b/tools/src/cmd/remote-compile/socket.h
@@ -24,29 +24,27 @@
/// Connects to the given TCP address and port.
/// @param address the target socket address
/// @param port the target socket port
- /// @param timeoutMillis the timeout for the connection attempt.
- /// If timeoutMillis is non-zero and no connection was made before
- /// timeoutMillis milliseconds, then nullptr is returned.
+ /// @param timeout_ms the timeout for the connection attempt.
+ /// If timeout_ms is non-zero and no connection was made before timeout_ms milliseconds,
+ /// then nullptr is returned.
/// @returns the connected Socket, or nullptr on failure
static std::shared_ptr<Socket> Connect(const char* address,
const char* port,
- uint32_t timeoutMillis);
+ uint32_t timeout_ms);
/// Begins listening for connections on the given TCP address and port.
/// Call Accept() on the returned Socket to block and wait for a connection.
- /// @param address the socket address to listen on. Use "localhost" for
- /// connections from only this machine, or an empty string to allow
- /// connections from any incoming address.
+ /// @param address the socket address to listen on. Use "localhost" for connections from only
+ /// this machine, or an empty string to allow connections from any incoming address.
/// @param port the socket port to listen on
/// @returns the Socket that listens for connections
static std::shared_ptr<Socket> Listen(const char* address, const char* port);
- /// Attempts to read at most `n` bytes into buffer, returning the actual
- /// number of bytes read.
+ /// Attempts to read at most `n` bytes into buffer, returning the actual number of bytes read.
/// read() will block until the socket is closed or at least one byte is read.
/// @param buffer the output buffer. Must be at least `n` bytes in size.
/// @param n the maximum number of bytes to read
- /// @return the number of bytes read, or 0 if the socket was closed
+ /// @return the number of bytes read, or 0 if the socket was closed or errored
virtual size_t Read(void* buffer, size_t n) = 0;
/// Writes `n` bytes from buffer into the socket.
@@ -62,8 +60,7 @@
/// Closes the socket.
virtual void Close() = 0;
- /// Blocks for a connection to be made to the listening port, or for the
- /// Socket to be closed.
+ /// Blocks for a connection to be made to the listening port, or for the Socket to be closed.
/// @returns a pointer to the next established incoming connection
virtual std::shared_ptr<Socket> Accept() = 0;
};