tools/remote-compile: Handle socket disconnects

If the socket disconnected mid-communication, the server could spin, waiting for new data.

Actually handle recv() errors, preventing the server spinning itself to death.

Also fix code style to be more tint-like (snake_case variables, PascalCase functions)

Change-Id: I9fcbfde303a8624e7e1ff87abd33581589f4da42
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105142
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/tools/src/cmd/remote-compile/main.cc b/tools/src/cmd/remote-compile/main.cc
index 20fb911..53a7deb 100644
--- a/tools/src/cmd/remote-compile/main.cc
+++ b/tools/src/cmd/remote-compile/main.cc
@@ -145,6 +145,8 @@
                 }
                 size -= n;
                 buf += n;
+            } else {
+                error = "Socket::Read() failed";
             }
         }
         return error.empty();
@@ -238,13 +240,15 @@
                                                                                MESSAGE& m) {
     Message::Type ty;
     s >> ty;
-    if (ty == m.type) {
-        m.Serialize([&s](auto& value) { s >> value; });
-    } else {
-        std::stringstream ss;
-        ss << "expected message type " << static_cast<int>(m.type) << ", got "
-           << static_cast<int>(ty);
-        s.error = ss.str();
+    if (s.error.empty()) {
+        if (ty == m.type) {
+            m.Serialize([&s](auto& value) { s >> value; });
+        } else {
+            std::stringstream ss;
+            ss << "expected message type " << static_cast<int>(m.type) << ", got "
+               << static_cast<int>(ty);
+            s.error = ss.str();
+        }
     }
     return s;
 }
@@ -291,8 +295,7 @@
             continue;
         }
 
-        // xcrun flags are ignored so this executable can be used as a replacement
-        // for xcrun.
+        // xcrun flags are ignored so this executable can be used as a replacement for xcrun.
         if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) {
             i++;
             continue;
@@ -357,7 +360,7 @@
                 ConnectionRequest req;
                 stream >> req;
                 if (!stream.error.empty()) {
-                    printf("%s\n", stream.error.c_str());
+                    DEBUG("%s", stream.error.c_str());
                     return;
                 }
                 ConnectionResponse resp;
@@ -374,7 +377,7 @@
                 CompileRequest req;
                 stream >> req;
                 if (!stream.error.empty()) {
-                    printf("%s\n", stream.error.c_str());
+                    DEBUG("%s\n", stream.error.c_str());
                     return;
                 }
 #ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API
diff --git a/tools/src/cmd/remote-compile/rwmutex.h b/tools/src/cmd/remote-compile/rwmutex.h
index 6970ff3..0a6381d 100644
--- a/tools/src/cmd/remote-compile/rwmutex.h
+++ b/tools/src/cmd/remote-compile/rwmutex.h
@@ -29,56 +29,56 @@
   public:
     inline RWMutex() = default;
 
-    /// lockReader() locks the mutex for reading.
+    /// LockReader() locks the mutex for reading.
     /// Multiple read locks can be held while there are no writer locks.
-    inline void lockReader();
+    inline void LockReader();
 
-    /// unlockReader() unlocks the mutex for reading.
-    inline void unlockReader();
+    /// UnlockReader() unlocks the mutex for reading.
+    inline void UnlockReader();
 
-    /// lockWriter() locks the mutex for writing.
-    /// If the lock is already locked for reading or writing, lockWriter blocks
+    /// LockWriter() locks the mutex for writing.
+    /// If the lock is already locked for reading or writing, LockWriter blocks
     /// until the lock is available.
-    inline void lockWriter();
+    inline void LockWriter();
 
-    /// unlockWriter() unlocks the mutex for writing.
-    inline void unlockWriter();
+    /// UnlockWriter() unlocks the mutex for writing.
+    inline void UnlockWriter();
 
   private:
     RWMutex(const RWMutex&) = delete;
     RWMutex& operator=(const RWMutex&) = delete;
 
-    int readLocks = 0;
-    int pendingWriteLocks = 0;
+    int read_locks = 0;
+    int pending_write_locks = 0;
     std::mutex mutex;
     std::condition_variable cv;
 };
 
-void RWMutex::lockReader() {
+void RWMutex::LockReader() {
     std::unique_lock<std::mutex> lock(mutex);
-    readLocks++;
+    read_locks++;
 }
 
-void RWMutex::unlockReader() {
+void RWMutex::UnlockReader() {
     std::unique_lock<std::mutex> lock(mutex);
-    readLocks--;
-    if (readLocks == 0 && pendingWriteLocks > 0) {
+    read_locks--;
+    if (read_locks == 0 && pending_write_locks > 0) {
         cv.notify_one();
     }
 }
 
-void RWMutex::lockWriter() {
+void RWMutex::LockWriter() {
     std::unique_lock<std::mutex> lock(mutex);
-    if (readLocks > 0) {
-        pendingWriteLocks++;
-        cv.wait(lock, [&] { return readLocks == 0; });
-        pendingWriteLocks--;
+    if (read_locks > 0) {
+        pending_write_locks++;
+        cv.wait(lock, [&] { return read_locks == 0; });
+        pending_write_locks--;
     }
     lock.release();  // Keep lock held
 }
 
-void RWMutex::unlockWriter() {
-    if (pendingWriteLocks > 0) {
+void RWMutex::UnlockWriter() {
+    if (pending_write_locks > 0) {
         cv.notify_one();
     }
     mutex.unlock();
@@ -115,12 +115,12 @@
 };
 
 RLock::RLock(RWMutex& mutex) : m(&mutex) {
-    m->lockReader();
+    m->LockReader();
 }
 
 RLock::~RLock() {
     if (m != nullptr) {
-        m->unlockReader();
+        m->UnlockReader();
     }
 }
 
@@ -167,12 +167,12 @@
 };
 
 WLock::WLock(RWMutex& mutex) : m(&mutex) {
-    m->lockWriter();
+    m->LockWriter();
 }
 
 WLock::~WLock() {
     if (m != nullptr) {
-        m->unlockWriter();
+        m->UnlockWriter();
     }
 }
 
diff --git a/tools/src/cmd/remote-compile/socket.cc b/tools/src/cmd/remote-compile/socket.cc
index a16940a..91e87a0 100644
--- a/tools/src/cmd/remote-compile/socket.cc
+++ b/tools/src/cmd/remote-compile/socket.cc
@@ -32,7 +32,7 @@
 #if defined(_WIN32)
 #include <atomic>
 namespace {
-std::atomic<int> wsaInitCount = {0};
+std::atomic<int> wsa_init_count = {0};
 }  // anonymous namespace
 #else
 #include <fcntl.h>
@@ -43,24 +43,24 @@
 
 namespace {
 constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
-void init() {
+void Init() {
 #if defined(_WIN32)
-    if (wsaInitCount++ == 0) {
-        WSADATA winsockData;
-        (void)WSAStartup(MAKEWORD(2, 2), &winsockData);
+    if (wsa_init_count++ == 0) {
+        WSADATA winsock_data;
+        (void)WSAStartup(MAKEWORD(2, 2), &winsock_data);
     }
 #endif
 }
 
-void term() {
+void Term() {
 #if defined(_WIN32)
-    if (--wsaInitCount == 0) {
+    if (--wsa_init_count == 0) {
         WSACleanup();
     }
 #endif
 }
 
-bool setBlocking(SOCKET s, bool blocking) {
+bool SetBlocking(SOCKET s, bool blocking) {
 #if defined(_WIN32)
     u_long mode = blocking ? 0 : 1;
     return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
@@ -74,7 +74,7 @@
 #endif
 }
 
-bool errored(SOCKET s) {
+bool Errored(SOCKET s) {
     if (s == InvalidSocket) {
         return true;
     }
@@ -87,7 +87,7 @@
 class Impl : public Socket {
   public:
     static std::shared_ptr<Impl> create(const char* address, const char* port) {
-        init();
+        Init();
 
         addrinfo hints = {};
         hints.ai_family = AF_INET;
@@ -106,12 +106,12 @@
         if (info) {
             auto socket = ::socket(info->ai_family, info->ai_socktype, info->ai_protocol);
             auto out = std::make_shared<Impl>(info, socket);
-            out->setOptions();
+            out->SetOptions();
             return out;
         }
 
         freeaddrinfo(info);
-        term();
+        Term();
         return nullptr;
     }
 
@@ -121,16 +121,16 @@
     ~Impl() {
         freeaddrinfo(info);
         Close();
-        term();
+        Term();
     }
 
     template <typename FUNCTION>
-    void lock(FUNCTION&& f) {
+    void Lock(FUNCTION&& f) {
         RLock l(mutex);
         f(s, info);
     }
 
-    void setOptions() {
+    void SetOptions() {
         RLock l(mutex);
         if (s == InvalidSocket) {
             return;
@@ -157,7 +157,7 @@
     bool IsOpen() override {
         {
             RLock l(mutex);
-            if ((s != InvalidSocket) && !errored(s)) {
+            if ((s != InvalidSocket) && !Errored(s)) {
                 return true;
             }
         }
@@ -188,12 +188,20 @@
     }
 
     size_t Read(void* buffer, size_t bytes) override {
-        RLock lock(mutex);
-        if (s == InvalidSocket) {
-            return 0;
+        {
+            RLock lock(mutex);
+            if (s == InvalidSocket) {
+                return 0;
+            }
+            size_t len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0);
+            if (len > 0) {
+                return len;
+            }
         }
-        auto len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0);
-        return (len < 0) ? 0 : len;
+        // Socket closed or errored
+        WLock lock(mutex);
+        s = InvalidSocket;
+        return 0;
     }
 
     bool Write(const void* buffer, size_t bytes) override {
@@ -209,11 +217,13 @@
 
     std::shared_ptr<Socket> Accept() override {
         std::shared_ptr<Impl> out;
-        lock([&](SOCKET socket, const addrinfo*) {
+        Lock([&](SOCKET socket, const addrinfo*) {
             if (socket != InvalidSocket) {
-                init();
-                out = std::make_shared<Impl>(::accept(socket, 0, 0));
-                out->setOptions();
+                Init();
+                if (auto s = ::accept(socket, 0, 0); s >= 0) {
+                    out = std::make_shared<Impl>(s);
+                    out->SetOptions();
+                }
             }
         });
         return out;
@@ -232,7 +242,7 @@
     if (!impl) {
         return nullptr;
     }
-    impl->lock([&](SOCKET socket, const addrinfo* info) {
+    impl->Lock([&](SOCKET socket, const addrinfo* info) {
         if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) {
             impl.reset();
             return;
@@ -248,46 +258,46 @@
 
 std::shared_ptr<Socket> Socket::Connect(const char* address,
                                         const char* port,
-                                        uint32_t timeoutMillis) {
+                                        uint32_t timeout_ms) {
     auto impl = Impl::create(address, port);
     if (!impl) {
         return nullptr;
     }
 
     std::shared_ptr<Socket> out;
-    impl->lock([&](SOCKET socket, const addrinfo* info) {
+    impl->Lock([&](SOCKET socket, const addrinfo* info) {
         if (socket == InvalidSocket) {
             return;
         }
 
-        if (timeoutMillis == 0) {
+        if (timeout_ms == 0) {
             if (::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) == 0) {
                 out = impl;
             }
             return;
         }
 
-        if (!setBlocking(socket, false)) {
+        if (!SetBlocking(socket, false)) {
             return;
         }
 
         auto res = ::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen));
         if (res == 0) {
-            if (setBlocking(socket, true)) {
+            if (SetBlocking(socket, true)) {
                 out = impl;
             }
         } else {
-            const auto microseconds = timeoutMillis * 1000;
+            const auto timeout_us = timeout_ms * 1000;
 
             fd_set fdset;
             FD_ZERO(&fdset);
             FD_SET(socket, &fdset);
 
             timeval tv;
-            tv.tv_sec = microseconds / 1000000;
-            tv.tv_usec = microseconds - static_cast<uint32_t>(tv.tv_sec * 1000000);
+            tv.tv_sec = timeout_us / 1000000;
+            tv.tv_usec = timeout_us - static_cast<uint32_t>(tv.tv_sec * 1000000);
             res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
-            if (res > 0 && !errored(socket) && setBlocking(socket, true)) {
+            if (res > 0 && !Errored(socket) && SetBlocking(socket, true)) {
                 out = impl;
             }
         }
diff --git a/tools/src/cmd/remote-compile/socket.h b/tools/src/cmd/remote-compile/socket.h
index 514ebac..59115ea 100644
--- a/tools/src/cmd/remote-compile/socket.h
+++ b/tools/src/cmd/remote-compile/socket.h
@@ -24,29 +24,27 @@
     /// Connects to the given TCP address and port.
     /// @param address the target socket address
     /// @param port the target socket port
-    /// @param timeoutMillis the timeout for the connection attempt.
-    ///        If timeoutMillis is non-zero and no connection was made before
-    ///        timeoutMillis milliseconds, then nullptr is returned.
+    /// @param timeout_ms the timeout for the connection attempt.
+    ///        If timeout_ms is non-zero and no connection was made before timeout_ms milliseconds,
+    ///        then nullptr is returned.
     /// @returns the connected Socket, or nullptr on failure
     static std::shared_ptr<Socket> Connect(const char* address,
                                            const char* port,
-                                           uint32_t timeoutMillis);
+                                           uint32_t timeout_ms);
 
     /// Begins listening for connections on the given TCP address and port.
     /// Call Accept() on the returned Socket to block and wait for a connection.
-    /// @param address the socket address to listen on. Use "localhost" for
-    ///        connections from only this machine, or an empty string to allow
-    ///        connections from any incoming address.
+    /// @param address the socket address to listen on. Use "localhost" for connections from only
+    ///        this machine, or an empty string to allow connections from any incoming address.
     /// @param port the socket port to listen on
     /// @returns the Socket that listens for connections
     static std::shared_ptr<Socket> Listen(const char* address, const char* port);
 
-    /// Attempts to read at most `n` bytes into buffer, returning the actual
-    /// number of bytes read.
+    /// Attempts to read at most `n` bytes into buffer, returning the actual number of bytes read.
     /// read() will block until the socket is closed or at least one byte is read.
     /// @param buffer the output buffer. Must be at least `n` bytes in size.
     /// @param n the maximum number of bytes to read
-    /// @return the number of bytes read, or 0 if the socket was closed
+    /// @return the number of bytes read, or 0 if the socket was closed or errored
     virtual size_t Read(void* buffer, size_t n) = 0;
 
     /// Writes `n` bytes from buffer into the socket.
@@ -62,8 +60,7 @@
     /// Closes the socket.
     virtual void Close() = 0;
 
-    /// Blocks for a connection to be made to the listening port, or for the
-    /// Socket to be closed.
+    /// Blocks for a connection to be made to the listening port, or for the Socket to be closed.
     /// @returns a pointer to the next established incoming connection
     virtual std::shared_ptr<Socket> Accept() = 0;
 };