From 2975facdd1ddb3b52d052aaaac611566f3ff27ae Mon Sep 17 00:00:00 2001 From: Ryan Ofsky Date: Wed, 30 Apr 2025 08:39:29 -0400 Subject: [PATCH] Add windows support Add support for running on windows. These changes make the libmultiprocess API more generic, using stream types instead of file descriptors. All features are supported, including spawning processes with socket connections to the parent process. These changes were originally made in https://github.com/bitcoin/bitcoin/pull/32387 --- example/calculator.cpp | 15 ++- example/example.cpp | 14 ++- example/printer.cpp | 16 ++- include/mp/proxy-io.h | 29 ++++-- include/mp/util.h | 46 +++++++-- src/mp/proxy.cpp | 88 +++++++++++++---- src/mp/util.cpp | 183 ++++++++++++++++++++++++++++------- test/mp/test/spawn_tests.cpp | 10 +- 8 files changed, 300 insertions(+), 101 deletions(-) diff --git a/example/calculator.cpp b/example/calculator.cpp index 86ce388b..8f74fade 100644 --- a/example/calculator.cpp +++ b/example/calculator.cpp @@ -6,19 +6,19 @@ #include #include // NOLINT(misc-include-cleaner) // IWYU pragma: keep -#include -#include +#include // IWYU pragma: keep #include #include #include #include +#include #include #include #include #include +#include #include #include -#include #include class CalculatorImpl : public Calculator @@ -51,14 +51,11 @@ int main(int argc, char** argv) std::cout << "Usage: mpcalculator \n"; return 1; } - int fd; - if (std::from_chars(argv[1], argv[1] + strlen(argv[1]), fd).ec != std::errc{}) { - std::cerr << argv[1] << " is not a number or is larger than an int\n"; - return 1; - } + mp::SocketId socket{mp::StartSpawned(argv[1])}; mp::EventLoop loop("mpcalculator", LogPrint); std::unique_ptr init = std::make_unique(); - mp::ServeStream(loop, fd, *init); + mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)}; + mp::ServeStream(loop, kj::mv(stream), *init); loop.loop(); return 0; } diff --git a/example/example.cpp b/example/example.cpp index 38313977..62e62290 100644 --- a/example/example.cpp +++ b/example/example.cpp @@ -5,13 +5,16 @@ #include #include +#include #include // IWYU pragma: keep #include #include #include #include #include +#include #include +#include #include #include #include @@ -19,20 +22,21 @@ #include #include #include +#include #include namespace fs = std::filesystem; static auto Spawn(mp::EventLoop& loop, const std::string& process_argv0, const std::string& new_exe_name) { - int pid; - const int fd = mp::SpawnProcess(pid, [&](int fd) -> std::vector { + auto pair{mp::SocketPair()}; + mp::ProcessId pid{mp::SpawnProcess(pair[0], [&](mp::ConnectInfo info) -> std::vector { fs::path path = process_argv0; path.remove_filename(); path.append(new_exe_name); - return {path.string(), std::to_string(fd)}; - }); - return std::make_tuple(mp::ConnectStream(loop, fd), pid); + return {path.string(), std::move(info)}; + })}; + return std::make_tuple(mp::ConnectStream(loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(pair[1])), pid); } static void LogPrint(mp::LogMessage log_data) diff --git a/example/printer.cpp b/example/printer.cpp index 9150d59b..68f62ff3 100644 --- a/example/printer.cpp +++ b/example/printer.cpp @@ -7,18 +7,19 @@ #include #include // NOLINT(misc-include-cleaner) // IWYU pragma: keep -#include -#include +#include // IWYU pragma: keep #include #include #include +#include #include #include #include #include +#include #include #include -#include +#include class PrinterImpl : public Printer { @@ -44,14 +45,11 @@ int main(int argc, char** argv) std::cout << "Usage: mpprinter \n"; return 1; } - int fd; - if (std::from_chars(argv[1], argv[1] + strlen(argv[1]), fd).ec != std::errc{}) { - std::cerr << argv[1] << " is not a number or is larger than an int\n"; - return 1; - } + mp::SocketId socket{mp::StartSpawned(argv[1])}; mp::EventLoop loop("mpprinter", LogPrint); std::unique_ptr init = std::make_unique(); - mp::ServeStream(loop, fd, *init); + mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)}; + mp::ServeStream(loop, std::move(stream), *init); loop.loop(); return 0; } diff --git a/include/mp/proxy-io.h b/include/mp/proxy-io.h index d7b9f0e5..42bf1c71 100644 --- a/include/mp/proxy-io.h +++ b/include/mp/proxy-io.h @@ -210,6 +210,17 @@ class Logger std::string LongThreadName(const char* exe_name); +using Stream = kj::Own; + +inline SocketId StreamSocketId(const Stream& stream) +{ + if (stream) KJ_IF_MAYBE(fd, stream->getFd()) return *fd; +#ifdef WIN32 + if (stream) KJ_IF_MAYBE(handle, stream->getWin32Handle()) return reinterpret_cast(*handle); +#endif + throw std::logic_error("Stream socket unset"); +} + //! Event loop implementation. //! //! Cap'n Proto threading model is very simple: all I/O operations are @@ -308,11 +319,12 @@ class EventLoop //! Callback functions to run on async thread. std::optional m_async_fns MP_GUARDED_BY(m_mutex); - //! Pipe read handle used to wake up the event loop thread. - int m_wait_fd = -1; + //! Socket pair used to post and wait for wakeups to the event loop thread. + kj::Own m_wait_stream; + kj::Own m_post_stream; - //! Pipe write handle used to wake up the event loop thread. - int m_post_fd = -1; + //! Synchronous writer used to write to m_post_stream. + kj::Own m_post_writer; //! Number of clients holding references to ProxyServerBase objects that //! reference this event loop. @@ -797,13 +809,11 @@ kj::Promise ProxyServer::post(Fn&& fn) //! over the stream. Also create a new Connection object embedded in the //! client that is freed when the client is closed. template -std::unique_ptr> ConnectStream(EventLoop& loop, int fd) +std::unique_ptr> ConnectStream(EventLoop& loop, kj::Own stream) { typename InitInterface::Client init_client(nullptr); std::unique_ptr connection; loop.sync([&] { - auto stream = - loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP); connection = std::make_unique(loop, kj::mv(stream)); init_client = connection->m_rpc_system->bootstrap(ServerVatId().vat_id).castAs(); Connection* connection_ptr = connection.get(); @@ -854,10 +864,9 @@ void _Listen(EventLoop& loop, kj::Own&& listener, InitIm //! Given stream file descriptor and an init object, handle requests on the //! stream by calling methods on the Init object. template -void ServeStream(EventLoop& loop, int fd, InitImpl& init) +void ServeStream(EventLoop& loop, kj::Own stream, InitImpl& init) { - _Serve( - loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP), init); + _Serve(loop, kj::mv(stream), init); } //! Given listening socket file descriptor and an init object, handle incoming diff --git a/include/mp/util.h b/include/mp/util.h index a3db1282..8f758e12 100644 --- a/include/mp/util.h +++ b/include/mp/util.h @@ -5,6 +5,7 @@ #ifndef MP_UTIL_H #define MP_UTIL_H +#include #include #include #include @@ -20,6 +21,10 @@ #include #include +#ifdef WIN32 +#include +#endif + namespace mp { //! Generic utility functions used by capnp code. @@ -249,17 +254,38 @@ std::string ThreadName(const char* exe_name); //! errors in python unit tests. std::string LogEscape(const kj::StringTree& string, size_t max_size); +#ifdef WIN32 +using ProcessId = uintptr_t; +using SocketId = uintptr_t; +constexpr SocketId SocketError{INVALID_SOCKET}; +#else +using ProcessId = int; +using SocketId = int; +constexpr SocketId SocketError{-1}; +#endif + +//! Information about parent process passed to child process. On unix this is +//! just the inherited int file descriptor formatted as a string. On windows, +//! this is a path to a named pipe the parent process will write +//! WSADuplicateSocket info to. +using ConnectInfo = std::string; + //! Callback type used by SpawnProcess below. -using FdToArgsFn = std::function(int fd)>; +using ConnectInfoToArgsFn = std::function(const ConnectInfo&)>; + +//! Create a socket pair that can be used to communicate within a process or +//! between parent and child processes. +std::array SocketPair(); + +//! Spawn a new process that communicates with the current process over provided +//! socket argument. Calls connect_info_to_args callback with a connection +//! string that needs to be passed to the child process, and executes the +//! argv command line it returns. Returns child process id. +ProcessId SpawnProcess(SocketId socket, ConnectInfoToArgsFn&& connect_info_to_args); -//! Spawn a new process that communicates with the current process over a socket -//! pair. Returns pid through an output argument, and file descriptor for the -//! local side of the socket. -//! The fd_to_args callback is invoked in the parent process before fork(). -//! It must not rely on child pid/state, and must return the command line -//! arguments that should be used to execute the process. Embed the remote file -//! descriptor number in whatever format the child process expects. -int SpawnProcess(int& pid, FdToArgsFn&& fd_to_args); +//! Initialize spawned child process using the ConnectInfo string passed to it, +//! returning a socket id for communicating with the parent process. +SocketId StartSpawned(const ConnectInfo& connect_info); //! Call execvp with vector args. //! Not safe to call in a post-fork child of a multi-threaded process. @@ -267,7 +293,7 @@ int SpawnProcess(int& pid, FdToArgsFn&& fd_to_args); void ExecProcess(const std::vector& args); //! Wait for a process to exit and return its exit code. -int WaitProcess(int pid); +int WaitProcess(ProcessId pid); inline char* CharCast(char* c) { return c; } inline char* CharCast(unsigned char* c) { return (char*)c; } diff --git a/src/mp/proxy.cpp b/src/mp/proxy.cpp index d24208db..95fe0564 100644 --- a/src/mp/proxy.cpp +++ b/src/mp/proxy.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -30,10 +31,8 @@ #include #include #include -#include #include #include -#include #include namespace mp { @@ -66,10 +65,9 @@ void EventLoopRef::reset(bool relock) MP_NO_TSA loop->m_num_clients -= 1; if (loop->done()) { loop->m_cv.notify_all(); - int post_fd{loop->m_post_fd}; loop_lock->unlock(); char buffer = 0; - KJ_SYSCALL(write(post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon) + loop->m_post_writer->write(&buffer, 1); // By default, do not try to relock `loop_lock` after writing, // because the event loop could wake up and destroy itself and the // mutex might no longer exist. @@ -100,6 +98,20 @@ Connection::~Connection() // after the calls finish. m_rpc_system.reset(); + // shutdownWrite is needed on Windows so pending data in the m_stream socket + // will be sent instead of discarded when m_stream is destroyed. On unix, + // this doesn't seem to be needed because data is sent more reliably. + // + // Sending pending data is important if the connection is a socketpair + // because when one side of the socketpair is closed, the other side doesn't + // seem to receive any onDisconnect event. So it is important for the other + // side to instead receive Cap'n Proto "release" messages (see `struct + // Release` in capnp/rpc.capnp) from local Client objects being destroyed so + // the remote side can free resources and shut down cleanly. Without this + // call, Server objects corresponding to the Client objects on the other + // side of the connection are not freed by Cap'n Proto. + m_stream->shutdownWrite(); + // ProxyClient cleanup handlers are in sync list, and ProxyServer cleanup // handlers are in the async list. // @@ -196,6 +208,40 @@ void EventLoop::addAsyncCleanup(std::function fn) startAsyncThread(); } +#ifdef WIN32 +//! Synchronous socket output stream. Cap'n Proto library only provides limited +//! support for synchronous IO. It provides `FdOutputStream` which wraps unix +//! file descriptors and calls write() internally, and `HandleOutStream` which +//! wraps windows HANDLE values and calls WriteFile() internally. This class +//! just provides analogous functionality wrapping SOCKET values and calls +//! send() internally. +class SocketOutputStream : public kj::OutputStream { +public: + explicit SocketOutputStream(SOCKET socket) : m_socket(socket) {} + + void write(const void* buffer, size_t size) override; + +private: + SOCKET m_socket; +}; + +static constexpr size_t WRITE_CLAMP_SIZE = 1u << 30; // 1GB clamp for Windows, like FdOutputStream + +void SocketOutputStream::write(const void* buffer, size_t size) { + const char* pos = reinterpret_cast(buffer); + + while (size > 0) { + int n = send(m_socket, pos, static_cast(kj::min(size, WRITE_CLAMP_SIZE)), 0); + + KJ_WIN32(n != SOCKET_ERROR, "send() failed"); + KJ_ASSERT(n > 0, "send() returned zero."); + + pos += n; + size -= n; + } +} +#endif + EventLoop::EventLoop(const char* exe_name, LogOptions log_opts, void* context) : m_exe_name(exe_name), m_io_context(kj::setupAsyncIo()), @@ -203,10 +249,18 @@ EventLoop::EventLoop(const char* exe_name, LogOptions log_opts, void* context) m_log_opts(std::move(log_opts)), m_context(context) { - int fds[2]; - KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds)); - m_wait_fd = fds[0]; - m_post_fd = fds[1]; + auto pipe = m_io_context.provider->newTwoWayPipe(); + m_wait_stream = kj::mv(pipe.ends[0]); + m_post_stream = kj::mv(pipe.ends[1]); + KJ_IF_MAYBE(fd, m_post_stream->getFd()) { + m_post_writer = kj::heap(*fd); +#ifdef WIN32 + } else KJ_IF_MAYBE(handle, m_post_stream->getWin32Handle()) { + m_post_writer = kj::heap(reinterpret_cast(*handle)); +#endif + } else { + throw std::logic_error("Could not get file descriptor for new pipe."); + } } EventLoop::~EventLoop() @@ -215,8 +269,8 @@ EventLoop::~EventLoop() const Lock lock(m_mutex); KJ_ASSERT(m_post_fn == nullptr); KJ_ASSERT(!m_async_fns); - KJ_ASSERT(m_wait_fd == -1); - KJ_ASSERT(m_post_fd == -1); + KJ_ASSERT(!m_wait_stream); + KJ_ASSERT(!m_post_stream); KJ_ASSERT(m_num_clients == 0); // Spin event loop. wait for any promises triggered by RPC shutdown. @@ -236,9 +290,7 @@ void EventLoop::loop() m_async_fns.emplace(); } - kj::Own wait_stream{ - m_io_context.lowLevelProvider->wrapSocketFd(m_wait_fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)}; - int post_fd{m_post_fd}; + kj::Own& wait_stream{m_wait_stream}; char buffer = 0; for (;;) { const size_t read_bytes = wait_stream->read(&buffer, 0, 1).wait(m_io_context.waitScope); @@ -250,7 +302,7 @@ void EventLoop::loop() m_cv.notify_all(); } else if (done()) { // Intentionally do not break if m_post_fn was set, even if done() - // would return true, to ensure that the EventLoopRef write(post_fd) + // would return true, to ensure that the EventLoopRef write(post_stream) // call always succeeds and the loop does not exit between the time // that the done condition is set and the write call is made. break; @@ -260,10 +312,9 @@ void EventLoop::loop() m_task_set.reset(); MP_LOG(*this, Log::Info) << "EventLoop::loop bye."; wait_stream = nullptr; - KJ_SYSCALL(::close(post_fd)); const Lock lock(m_mutex); - m_wait_fd = -1; - m_post_fd = -1; + m_wait_stream = nullptr; + m_post_stream = nullptr; m_async_fns.reset(); m_cv.notify_all(); } @@ -278,10 +329,9 @@ void EventLoop::post(kj::Function fn) EventLoopRef ref(*this, &lock); m_cv.wait(lock.m_lock, [this]() MP_REQUIRES(m_mutex) { return m_post_fn == nullptr; }); m_post_fn = &fn; - int post_fd{m_post_fd}; Unlock(lock, [&] { char buffer = 0; - KJ_SYSCALL(write(post_fd, &buffer, 1)); + m_post_writer->write(&buffer, 1); }); m_cv.wait(lock.m_lock, [this, &fn]() MP_REQUIRES(m_mutex) { return m_post_fn != &fn; }); } diff --git a/src/mp/util.cpp b/src/mp/util.cpp index 463947b9..45220a80 100644 --- a/src/mp/util.cpp +++ b/src/mp/util.cpp @@ -10,20 +10,27 @@ #include #include #include +#include #include #include #include #include -#include -#include -#include -#include #include #include // NOLINT(misc-include-cleaner) // IWYU pragma: keep #include #include #include +#ifdef WIN32 +#include +#include +#else +#include +#include +#include +#include +#endif + #ifdef __linux__ #include #endif @@ -34,6 +41,11 @@ namespace fs = std::filesystem; +#ifdef WIN32 +// Forward-declare internal capnp function. +namespace kj { namespace _ { int win32Socketpair(SOCKET socks[2]); } } +#endif + namespace mp { namespace { @@ -48,6 +60,7 @@ std::vector MakeArgv(const std::vector& args) return argv; } +#ifndef WIN32 //! Return highest possible file descriptor. size_t MaxFd() { @@ -58,6 +71,7 @@ size_t MaxFd() return 1023; } } +#endif } // namespace @@ -79,6 +93,8 @@ std::string ThreadName(const char* exe_name) // the former are shorter and are the same as what gdb prints "LWP ...". #ifdef __linux__ buffer << syscall(SYS_gettid); +#elif defined(WIN32) + buffer << GetCurrentThreadId(); #elif defined(HAVE_PTHREAD_THREADID_NP) uint64_t tid = 0; pthread_threadid_np(nullptr, &tid); @@ -116,59 +132,147 @@ std::string LogEscape(const kj::StringTree& string, size_t max_size) return result; } -int SpawnProcess(int& pid, FdToArgsFn&& fd_to_args) +std::array SocketPair() { - int fds[2]; - if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) { - throw std::system_error(errno, std::system_category(), "socketpair"); +#ifdef WIN32 + SOCKET pair[2]; + KJ_WINSOCK(kj::_::win32Socketpair(pair)); +#else + int pair[2]; + KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, pair)); +#endif + return {pair[0], pair[1]}; +} + +//! Generate command line that the executable being invoked will split up using +//! the CommandLineToArgvW function, which expects arguments with spaces to be +//! quoted, quote characters to be backslash-escaped, and backslashes to also be +//! backslash-escaped, but only if they precede a quote character. +std::string CommandLineFromArgv(const std::vector& argv) +{ + std::string out; + for (const auto& arg : argv) { + if (!out.empty()) out += " "; + if (!arg.empty() && arg.find_first_of(" \t\"") == std::string::npos) { + // Argument has no quotes or spaces so escaping not necessary. + out += arg; + } else { + out += '"'; // Start with a quote + for (size_t i = 0; i < arg.size(); ++i) { + if (arg[i] == '\\') { + // Count consecutive backslashes + size_t backslash_count = 0; + while (i < arg.size() && arg[i] == '\\') { + ++backslash_count; + ++i; + } + if (i < arg.size() && arg[i] == '"') { + // Backslashes before a quote need to be doubled + out.append(backslash_count * 2 + 1, '\\'); + out.push_back('"'); + } else { + // Otherwise, backslashes remain as-is + out.append(backslash_count, '\\'); + --i; // Compensate for the outer loop's increment + } + } else if (arg[i] == '"') { + // Escape double quotes with a backslash + out.push_back('\\'); + out.push_back('"'); + } else { + out.push_back(arg[i]); + } + } + out += '"'; // End with a quote + } } + return out; +} +ProcessId SpawnProcess(SocketId socket, ConnectInfoToArgsFn&& connect_info_to_args) +{ +#ifndef WIN32 // Evaluate the callback and build the argv array before forking. // // The parent process may be multi-threaded and holding internal library // locks at fork time. In that case, running code that allocates memory or // takes locks in the child between fork() and exec() can deadlock // indefinitely. Precomputing arguments in the parent avoids this. - const std::vector args{fd_to_args(fds[0])}; - const std::vector argv{MakeArgv(args)}; + const std::vector args{connect_info_to_args(std::to_string(socket))}; - pid = fork(); + int pid{fork()}; if (pid == -1) { throw std::system_error(errno, std::system_category(), "fork"); } - // Parent process closes the descriptor for socket 0, child closes the - // descriptor for socket 1. On failure, the parent throws, but the child - // must _exit(126) (post-fork child must not throw). - if (close(fds[pid ? 0 : 1]) != 0) { - if (pid) { - (void)close(fds[1]); - throw std::system_error(errno, std::system_category(), "close"); - } - static constexpr char msg[] = "SpawnProcess(child): close(fds[1]) failed\n"; - const ssize_t writeResult = ::write(STDERR_FILENO, msg, sizeof(msg) - 1); - (void)writeResult; - _exit(126); - } - if (!pid) { // Child process must close all potentially open descriptors, except // socket 0. Do not throw, allocate, or do non-fork-safe work here. const int maxFd = MaxFd(); for (int fd = 3; fd < maxFd; ++fd) { - if (fd != fds[0]) { + if (fd != socket) { close(fd); } } - execvp(argv[0], argv.data()); - // NOTE: perror() is not async-signal-safe; calling it here in a - // post-fork child may deadlock in multithreaded parents. - // TODO: Report errors to the parent via a pipe (e.g. write errno) - // so callers can get diagnostics without relying on perror(). - perror("execvp failed"); - _exit(127); + int flags = fcntl(socket, F_GETFD); + if (flags == -1) throw std::system_error(errno, std::system_category(), "fcntl F_GETFD"); + if (flags & FD_CLOEXEC) { + flags &= ~FD_CLOEXEC; + if (fcntl(socket, F_SETFD, flags) == -1) throw std::system_error(errno, std::system_category(), "fcntl F_SETFD"); + } + + ExecProcess(args); } - return fds[1]; + return pid; +#else + // Create windows pipe to send pipe.ends[0] over to child process. + static std::atomic counter{1}; + ConnectInfo pipe_path{"\\\\.\\pipe\\mp-" + std::to_string(GetCurrentProcessId()) + "-" + std::to_string(counter.fetch_add(1))}; + HANDLE pipe{CreateNamedPipeA(pipe_path.c_str(), PIPE_ACCESS_OUTBOUND, PIPE_TYPE_MESSAGE | PIPE_WAIT, 1, 0, 0, 0, nullptr)}; + KJ_WIN32(pipe != INVALID_HANDLE_VALUE, "CreateNamedPipe failed"); + + // Start child process + std::string cmd{CommandLineFromArgv(connect_info_to_args(pipe_path))}; + STARTUPINFOA si{}; + si.cb = sizeof(si); + PROCESS_INFORMATION pi{}; + KJ_WIN32(CreateProcessA(nullptr, const_cast(cmd.c_str()), nullptr, nullptr, TRUE, 0, nullptr, nullptr, &si, &pi), "CreateProcess failed"); + CloseHandle(pi.hThread); // not needed + + // Duplicate socket for the child (now that we know its PID) + WSAPROTOCOL_INFO info{}; + KJ_WINSOCK(WSADuplicateSocket(socket, pi.dwProcessId, &info), "WSADuplicateSocket failed"); + + // Send socket to the child via the pipe + KJ_WIN32(ConnectNamedPipe(pipe, nullptr) || GetLastError() == ERROR_PIPE_CONNECTED, "ConnectNamedPipe failed"); + DWORD wr; + KJ_WIN32(WriteFile(pipe, &info, sizeof(info), &wr, nullptr) && wr == sizeof(info), "WriteFile(pipe) failed"); + CloseHandle(pipe); + + return reinterpret_cast(pi.hProcess); +#endif +} + +SocketId StartSpawned(const ConnectInfo& connect_info) +{ +#ifndef WIN32 + return std::stoi(connect_info); +#else + HANDLE pipe = CreateFileA(connect_info.c_str(), GENERIC_READ, 0, nullptr, OPEN_EXISTING, 0, nullptr); + KJ_WIN32(pipe != INVALID_HANDLE_VALUE, "CreateFile(pipe) failed"); + + WSAPROTOCOL_INFO info{}; + DWORD rd; + KJ_WIN32(ReadFile(pipe, &info, sizeof(info), &rd, nullptr) && rd == sizeof(info), "ReadFile(pipe) failed"); + CloseHandle(pipe); + + WSADATA dontcare; + KJ_WIN32(WSAStartup(MAKEWORD(2, 2), &dontcare) != 0, "WSAStartup() failed"); + + SOCKET socket{WSASocketA(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, &info, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT)}; + KJ_WINSOCK(socket, "WSASocket(FROM_PROTOCOL_INFO) failed"); + return socket; +#endif } void ExecProcess(const std::vector& args) @@ -183,13 +287,22 @@ void ExecProcess(const std::vector& args) } } -int WaitProcess(int pid) +int WaitProcess(ProcessId pid) { +#ifndef WIN32 int status; if (::waitpid(pid, &status, /*options=*/0) != pid) { throw std::system_error(errno, std::system_category(), "waitpid"); } return status; +#else + HANDLE handle{reinterpret_cast(pid)}; + DWORD result{WaitForSingleObject(handle, INFINITE)}; + KJ_WIN32(result != WAIT_OBJECT_0, "WaitForSingleObject(child) failed"); + KJ_WIN32(GetExitCodeProcess(handle, &result), "GetExitCodeProcess failed"); + CloseHandle(handle); + return result; +#endif } } // namespace mp diff --git a/test/mp/test/spawn_tests.cpp b/test/mp/test/spawn_tests.cpp index a14e50e2..60a6d090 100644 --- a/test/mp/test/spawn_tests.cpp +++ b/test/mp/test/spawn_tests.cpp @@ -6,6 +6,7 @@ #include +#include #include #include #include @@ -86,14 +87,15 @@ KJ_TEST("SpawnProcess does not run callback in child") control_cv.notify_one(); }); - int pid{-1}; - const int fd{mp::SpawnProcess(pid, [&](int child_fd) -> std::vector { + auto [parent_fd, child_fd] = mp::SocketPair(); + const mp::ProcessId pid{mp::SpawnProcess(child_fd, [&](const mp::ConnectInfo& connect_info) -> std::vector { // If this callback runs in the post-fork child, target_mutex appears // locked forever (the owning thread does not exist), so this deadlocks. std::lock_guard g(target_mutex); - return {"true", std::to_string(child_fd)}; + return {"true", connect_info}; })}; - ::close(fd); + ::close(child_fd); + ::close(parent_fd); int status{0}; // Give the child some time to exit. If it does not, terminate it and