| // Copyright 2021 The Tint Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
#include <stdio.h>
#include <stdlib.h>
#include <fstream>
#include <sstream>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

#include "tools/src/cmd/remote-compile/compile.h"
#include "tools/src/cmd/remote-compile/socket.h"
| |
| namespace { |
| |
// Debug logging. Change `#if 0` to `#if 1` to make DEBUG(msg, ...) print the
// printf-style message followed by a newline. `msg` must be a string literal,
// because it is concatenated with "\n". When disabled, DEBUG(...) expands to
// nothing, so its arguments are not evaluated.
#if 0
#define DEBUG(msg, ...) printf(msg "\n", ##__VA_ARGS__)
#else
#define DEBUG(...)
#endif
| |
/// Prints the tool usage to stdout, then terminates the process with exit
/// code 1. Marked [[noreturn]] so that callers (and the compiler) know control
/// never comes back after a usage error.
[[noreturn]] void ShowUsage() {
  const char* name = "tint-remote-compile";
  printf(R"(%s is a tool for compiling a shader on a remote machine

usage as server:
%s -s [-p port-number]

usage as client:
%s [-p port-number] [server-address] shader-file-path

[server-address] can be omitted if the TINT_REMOTE_COMPILE_ADDRESS environment
variable is set.
Alternatively, you can pass xcrun arguments so %s can be used as a
drop-in replacement.
)",
         name, name, name, name);
  exit(1);
}
| |
/// Version of the client/server wire protocol. The server refuses any client
/// whose ConnectionRequest carries a different value, so bump this whenever
/// the protocol changes.
constexpr uint32_t kProtocolVersion{1};
| |
/// Supported shader source languages.
///
/// The value travels on the wire as a uint32_t and is converted back with
/// static_cast on the receiving side. The underlying type is therefore fixed
/// to uint32_t, so that any received value (including unknown ones from a
/// newer or hostile peer) converts to a SourceLanguage with well-defined
/// behavior. Without a fixed underlying type, only the values of the
/// enumerators' bit range are representable.
enum SourceLanguage : uint32_t {
  MSL,
};
| |
| /// Stream is a serialization wrapper around a socket |
| struct Stream { |
| /// The underlying socket |
| Socket* const socket; |
| /// Error state |
| std::string error; |
| |
| /// Writes a uint32_t to the socket |
| Stream operator<<(uint32_t v) { |
| if (error.empty()) { |
| Write(&v, sizeof(v)); |
| } |
| return *this; |
| } |
| |
| /// Reads a uint32_t from the socket |
| Stream operator>>(uint32_t& v) { |
| if (error.empty()) { |
| Read(&v, sizeof(v)); |
| } |
| return *this; |
| } |
| |
| /// Writes a std::string to the socket |
| Stream operator<<(const std::string& v) { |
| if (error.empty()) { |
| uint32_t count = static_cast<uint32_t>(v.size()); |
| *this << count; |
| if (count) { |
| Write(v.data(), count); |
| } |
| } |
| return *this; |
| } |
| |
| /// Reads a std::string from the socket |
| Stream operator>>(std::string& v) { |
| uint32_t count = 0; |
| *this >> count; |
| if (count) { |
| std::vector<char> buf(count); |
| if (Read(buf.data(), count)) { |
| v = std::string(buf.data(), buf.size()); |
| } |
| } else { |
| v.clear(); |
| } |
| return *this; |
| } |
| |
| /// Writes an enum value to the socket |
| template <typename T> |
| std::enable_if_t<std::is_enum<T>::value, Stream> operator<<(T e) { |
| return *this << static_cast<uint32_t>(e); |
| } |
| |
| /// Reads an enum value from the socket |
| template <typename T> |
| std::enable_if_t<std::is_enum<T>::value, Stream> operator>>(T& e) { |
| uint32_t v; |
| *this >> v; |
| e = static_cast<T>(v); |
| return *this; |
| } |
| |
| private: |
| bool Write(const void* data, size_t size) { |
| if (error.empty()) { |
| if (!socket->Write(data, size)) { |
| error = "Socket::Write() failed"; |
| } |
| } |
| return error.empty(); |
| } |
| |
| bool Read(void* data, size_t size) { |
| auto buf = reinterpret_cast<uint8_t*>(data); |
| while (size > 0 && error.empty()) { |
| if (auto n = socket->Read(buf, size)) { |
| if (n > size) { |
| error = "Socket::Read() returned more bytes than requested"; |
| return false; |
| } |
| size -= n; |
| buf += n; |
| } |
| } |
| return error.empty(); |
| } |
| }; |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Messages |
| //////////////////////////////////////////////////////////////////////////////// |
| |
/// Common base of every message exchanged between client and server. On the
/// wire, each message is preceded by its Type tag, which the reader checks
/// against the type it expects.
struct Message {
  /// Identifies the kind of a message.
  enum class Type {
    ConnectionRequest,
    ConnectionResponse,
    CompileRequest,
    CompileResponse,
  };

  explicit Message(Type ty) : type{ty} {}

  /// The kind of this message, fixed at construction.
  const Type type;
};
| |
| struct ConnectionResponse : Message { // Server -> Client |
| ConnectionResponse() : Message(Type::ConnectionResponse) {} |
| |
| template <typename T> |
| void Serialize(T&& f) { |
| f(error); |
| } |
| |
| std::string error; |
| }; |
| |
| struct ConnectionRequest : Message { // Client -> Server |
| using Response = ConnectionResponse; |
| |
| explicit ConnectionRequest(uint32_t proto_ver = kProtocolVersion) |
| : Message(Type::ConnectionRequest), protocol_version(proto_ver) {} |
| |
| template <typename T> |
| void Serialize(T&& f) { |
| f(protocol_version); |
| } |
| |
| uint32_t protocol_version; |
| }; |
| |
| struct CompileResponse : Message { // Server -> Client |
| CompileResponse() : Message(Type::CompileResponse) {} |
| |
| template <typename T> |
| void Serialize(T&& f) { |
| f(error); |
| } |
| |
| std::string error; |
| }; |
| |
| struct CompileRequest : Message { // Client -> Server |
| using Response = CompileResponse; |
| |
| CompileRequest() : Message(Type::CompileRequest) {} |
| CompileRequest(SourceLanguage lang, std::string src) |
| : Message(Type::CompileRequest), language(lang), source(src) {} |
| |
| template <typename T> |
| void Serialize(T&& f) { |
| f(language); |
| f(source); |
| } |
| |
| SourceLanguage language; |
| std::string source; |
| }; |
| |
| /// Writes the message `m` to the stream `s` |
| template <typename MESSAGE> |
| std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator<<( |
| Stream& s, |
| const MESSAGE& m) { |
| s << m.type; |
| const_cast<MESSAGE&>(m).Serialize([&s](const auto& value) { s << value; }); |
| return s; |
| } |
| |
| /// Reads the message `m` from the stream `s` |
| template <typename MESSAGE> |
| std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator>>( |
| Stream& s, |
| MESSAGE& m) { |
| Message::Type ty; |
| s >> ty; |
| if (ty == m.type) { |
| m.Serialize([&s](auto& value) { s >> value; }); |
| } else { |
| std::stringstream ss; |
| ss << "expected message type " << static_cast<int>(m.type) << ", got " |
| << static_cast<int>(ty); |
| s.error = ss.str(); |
| } |
| return s; |
| } |
| |
| /// Writes the request message `req` to the stream `s`, then reads and returns |
| /// the response message from the same stream. |
| template <typename REQUEST, typename RESPONSE = typename REQUEST::Response> |
| RESPONSE Send(Stream& s, const REQUEST& req) { |
| s << req; |
| if (s.error.empty()) { |
| RESPONSE resp; |
| s >> resp; |
| if (s.error.empty()) { |
| return resp; |
| } |
| } |
| return {}; |
| } |
| |
| } // namespace |
| |
| bool RunServer(std::string port); |
| bool RunClient(std::string address, std::string port, std::string file); |
| |
| int main(int argc, char* argv[]) { |
| bool run_server = false; |
| std::string port = "19000"; |
| |
| std::vector<std::string> args; |
| for (int i = 1; i < argc; i++) { |
| std::string arg = argv[i]; |
| if (arg == "-s" || arg == "--server") { |
| run_server = true; |
| continue; |
| } |
| if (arg == "-p" || arg == "--port") { |
| if (i < argc - 1) { |
| i++; |
| port = argv[i]; |
| } else { |
| printf("expected port number"); |
| exit(1); |
| } |
| continue; |
| } |
| |
| // xcrun flags are ignored so this executable can be used as a replacement |
| // for xcrun. |
| if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) { |
| i++; |
| continue; |
| } |
| if (arg == "metal") { |
| for (; i < argc; i++) { |
| if (std::string(argv[i]) == "-c") { |
| break; |
| } |
| } |
| continue; |
| } |
| |
| args.emplace_back(arg); |
| } |
| |
| bool success = false; |
| |
| if (run_server) { |
| success = RunServer(port); |
| } else { |
| std::string address; |
| std::string file; |
| switch (args.size()) { |
| case 1: |
| if (auto* addr = getenv("TINT_REMOTE_COMPILE_ADDRESS")) { |
| address = addr; |
| } |
| file = args[0]; |
| break; |
| case 2: |
| address = args[0]; |
| file = args[1]; |
| break; |
| } |
| if (address.empty() || file.empty()) { |
| ShowUsage(); |
| } |
| success = RunClient(address, port, file); |
| } |
| |
| if (!success) { |
| exit(1); |
| } |
| |
| return 0; |
| } |
| |
| bool RunServer(std::string port) { |
| auto server_socket = Socket::Listen("", port.c_str()); |
| if (!server_socket) { |
| printf("Failed to listen on port %s\n", port.c_str()); |
| return false; |
| } |
| printf("Listening on port %s...\n", port.c_str()); |
| while (auto conn = server_socket->Accept()) { |
| std::thread([=] { |
| DEBUG("Client connected..."); |
| Stream stream{conn.get()}; |
| |
| { |
| ConnectionRequest req; |
| stream >> req; |
| if (!stream.error.empty()) { |
| printf("%s\n", stream.error.c_str()); |
| return; |
| } |
| ConnectionResponse resp; |
| if (req.protocol_version != kProtocolVersion) { |
| DEBUG("Protocol version mismatch"); |
| resp.error = "Protocol version mismatch"; |
| stream << resp; |
| return; |
| } |
| stream << resp; |
| } |
| DEBUG("Connection established"); |
| { |
| CompileRequest req; |
| stream >> req; |
| if (!stream.error.empty()) { |
| printf("%s\n", stream.error.c_str()); |
| return; |
| } |
| #ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API |
| if (req.language == SourceLanguage::MSL) { |
| auto result = CompileMslUsingMetalAPI(req.source); |
| CompileResponse resp; |
| if (!result.success) { |
| resp.error = result.output; |
| } |
| stream << resp; |
| return; |
| } |
| #endif |
| CompileResponse resp; |
| resp.error = "server cannot compile this type of shader"; |
| stream << resp; |
| } |
| }).detach(); |
| } |
| return true; |
| } |
| |
| bool RunClient(std::string address, std::string port, std::string file) { |
| // Read the file |
| std::ifstream input(file, std::ios::binary); |
| if (!input) { |
| printf("Couldn't open '%s'\n", file.c_str()); |
| return false; |
| } |
| std::string source((std::istreambuf_iterator<char>(input)), |
| std::istreambuf_iterator<char>()); |
| |
| constexpr const int timeout_ms = 10000; |
| DEBUG("Connecting to %s:%s...", address.c_str(), port.c_str()); |
| auto conn = Socket::Connect(address.c_str(), port.c_str(), timeout_ms); |
| if (!conn) { |
| printf("Connection failed\n"); |
| return false; |
| } |
| |
| Stream stream{conn.get()}; |
| |
| DEBUG("Sending connection request..."); |
| auto conn_resp = Send(stream, ConnectionRequest{kProtocolVersion}); |
| if (!stream.error.empty()) { |
| printf("%s\n", stream.error.c_str()); |
| return false; |
| } |
| if (!conn_resp.error.empty()) { |
| printf("%s\n", conn_resp.error.c_str()); |
| return false; |
| } |
| DEBUG("Connection established. Requesting compile..."); |
| auto comp_resp = Send(stream, CompileRequest{SourceLanguage::MSL, source}); |
| if (!stream.error.empty()) { |
| printf("%s\n", stream.error.c_str()); |
| return false; |
| } |
| if (!comp_resp.error.empty()) { |
| printf("%s\n", comp_resp.error.c_str()); |
| return false; |
| } |
| DEBUG("Compilation successful"); |
| return true; |
| } |