Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions example/calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
#include <init.capnp.h>
#include <init.capnp.proxy.h> // NOLINT(misc-include-cleaner) // IWYU pragma: keep

#include <charconv>
#include <cstring>
#include <cstring> // IWYU pragma: keep
#include <fstream>
#include <functional>
#include <iostream>
#include <kj/async.h>
#include <kj/async-io.h>
#include <kj/common.h>
#include <kj/memory.h>
#include <memory>
#include <mp/proxy-io.h>
#include <mp/util.h>
#include <stdexcept>
#include <string>
#include <system_error>
#include <utility>

class CalculatorImpl : public Calculator
Expand Down Expand Up @@ -51,14 +51,11 @@ int main(int argc, char** argv)
std::cout << "Usage: mpcalculator <fd>\n";
return 1;
}
int fd;
if (std::from_chars(argv[1], argv[1] + strlen(argv[1]), fd).ec != std::errc{}) {
std::cerr << argv[1] << " is not a number or is larger than an int\n";
return 1;
}
mp::SocketId socket{mp::StartSpawned(argv[1])};
mp::EventLoop loop("mpcalculator", LogPrint);
std::unique_ptr<Init> init = std::make_unique<InitImpl>();
mp::ServeStream<InitInterface>(loop, fd, *init);
mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)};
mp::ServeStream<InitInterface>(loop, kj::mv(stream), *init);
loop.loop();
return 0;
}
14 changes: 9 additions & 5 deletions example/example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,38 @@
#include <init.capnp.h>
#include <init.capnp.proxy.h>

#include <array>
#include <cstring> // IWYU pragma: keep
#include <filesystem>
#include <fstream>
#include <future>
#include <iostream>
#include <kj/async.h>
#include <kj/async-io.h>
#include <kj/common.h>
#include <kj/memory.h>
#include <memory>
#include <mp/proxy-io.h>
#include <mp/util.h>
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>

namespace fs = std::filesystem;

static auto Spawn(mp::EventLoop& loop, const std::string& process_argv0, const std::string& new_exe_name)
{
int pid;
const int fd = mp::SpawnProcess(pid, [&](int fd) -> std::vector<std::string> {
auto pair{mp::SocketPair()};
mp::ProcessId pid{mp::SpawnProcess(pair[0], [&](mp::ConnectInfo info) -> std::vector<std::string> {
fs::path path = process_argv0;
path.remove_filename();
path.append(new_exe_name);
return {path.string(), std::to_string(fd)};
});
return std::make_tuple(mp::ConnectStream<InitInterface>(loop, fd), pid);
return {path.string(), std::move(info)};
})};
return std::make_tuple(mp::ConnectStream<InitInterface>(loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(pair[1])), pid);
}

static void LogPrint(mp::LogMessage log_data)
Expand Down
16 changes: 7 additions & 9 deletions example/printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
#include <init.capnp.h>
#include <init.capnp.proxy.h> // NOLINT(misc-include-cleaner) // IWYU pragma: keep

#include <charconv>
#include <cstring>
#include <cstring> // IWYU pragma: keep
#include <fstream>
#include <iostream>
#include <kj/async.h>
#include <kj/async-io.h>
#include <kj/common.h>
#include <kj/memory.h>
#include <memory>
#include <mp/proxy-io.h>
#include <mp/util.h>
#include <stdexcept>
#include <string>
#include <system_error>
#include <utility>

class PrinterImpl : public Printer
{
Expand All @@ -44,14 +45,11 @@ int main(int argc, char** argv)
std::cout << "Usage: mpprinter <fd>\n";
return 1;
}
int fd;
if (std::from_chars(argv[1], argv[1] + strlen(argv[1]), fd).ec != std::errc{}) {
std::cerr << argv[1] << " is not a number or is larger than an int\n";
return 1;
}
mp::SocketId socket{mp::StartSpawned(argv[1])};
mp::EventLoop loop("mpprinter", LogPrint);
std::unique_ptr<Init> init = std::make_unique<InitImpl>();
mp::ServeStream<InitInterface>(loop, fd, *init);
mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)};
mp::ServeStream<InitInterface>(loop, std::move(stream), *init);
loop.loop();
return 0;
}
29 changes: 19 additions & 10 deletions include/mp/proxy-io.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,17 @@ class Logger

std::string LongThreadName(const char* exe_name);

using Stream = kj::Own<kj::AsyncIoStream>;

inline SocketId StreamSocketId(const Stream& stream)
{
if (stream) KJ_IF_MAYBE(fd, stream->getFd()) return *fd;
#ifdef WIN32
if (stream) KJ_IF_MAYBE(handle, stream->getWin32Handle()) return reinterpret_cast<SocketId>(*handle);
#endif
throw std::logic_error("Stream socket unset");
}

//! Event loop implementation.
//!
//! Cap'n Proto threading model is very simple: all I/O operations are
Expand Down Expand Up @@ -308,11 +319,12 @@ class EventLoop
//! Callback functions to run on async thread.
std::optional<CleanupList> m_async_fns MP_GUARDED_BY(m_mutex);

//! Pipe read handle used to wake up the event loop thread.
int m_wait_fd = -1;
//! Socket pair used to post and wait for wakeups to the event loop thread.
kj::Own<kj::AsyncIoStream> m_wait_stream;
kj::Own<kj::AsyncIoStream> m_post_stream;

//! Pipe write handle used to wake up the event loop thread.
int m_post_fd = -1;
//! Synchronous writer used to write to m_post_stream.
kj::Own<kj::OutputStream> m_post_writer;

//! Number of clients holding references to ProxyServerBase objects that
//! reference this event loop.
Expand Down Expand Up @@ -797,13 +809,11 @@ kj::Promise<T> ProxyServer<Thread>::post(Fn&& fn)
//! over the stream. Also create a new Connection object embedded in the
//! client that is freed when the client is closed.
template <typename InitInterface>
std::unique_ptr<ProxyClient<InitInterface>> ConnectStream(EventLoop& loop, int fd)
std::unique_ptr<ProxyClient<InitInterface>> ConnectStream(EventLoop& loop, kj::Own<kj::AsyncIoStream> stream)
{
typename InitInterface::Client init_client(nullptr);
std::unique_ptr<Connection> connection;
loop.sync([&] {
auto stream =
loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP);
connection = std::make_unique<Connection>(loop, kj::mv(stream));
init_client = connection->m_rpc_system->bootstrap(ServerVatId().vat_id).castAs<InitInterface>();
Connection* connection_ptr = connection.get();
Expand Down Expand Up @@ -854,10 +864,9 @@ void _Listen(EventLoop& loop, kj::Own<kj::ConnectionReceiver>&& listener, InitIm
//! Given stream file descriptor and an init object, handle requests on the
//! stream by calling methods on the Init object.
template <typename InitInterface, typename InitImpl>
void ServeStream(EventLoop& loop, int fd, InitImpl& init)
void ServeStream(EventLoop& loop, kj::Own<kj::AsyncIoStream> stream, InitImpl& init)
{
_Serve<InitInterface>(
loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP), init);
_Serve<InitInterface>(loop, kj::mv(stream), init);
}

//! Given listening socket file descriptor and an init object, handle incoming
Expand Down
46 changes: 36 additions & 10 deletions include/mp/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef MP_UTIL_H
#define MP_UTIL_H

#include <array>
#include <capnp/schema.h>
#include <cassert>
#include <cstddef>
Expand All @@ -20,6 +21,10 @@
#include <variant>
#include <vector>

#ifdef WIN32
#include <winsock2.h>
#endif

namespace mp {

//! Generic utility functions used by capnp code.
Expand Down Expand Up @@ -249,25 +254,46 @@ std::string ThreadName(const char* exe_name);
//! errors in python unit tests.
std::string LogEscape(const kj::StringTree& string, size_t max_size);

#ifdef WIN32
using ProcessId = uintptr_t;
using SocketId = uintptr_t;
constexpr SocketId SocketError{INVALID_SOCKET};
#else
using ProcessId = int;
using SocketId = int;
constexpr SocketId SocketError{-1};
#endif

//! Information about parent process passed to child process. On unix this is
//! just the inherited int file descriptor formatted as a string. On windows,
//! this is a path to a named pipe the parent process will write
//! WSADuplicateSocket info to.
using ConnectInfo = std::string;

//! Callback type used by SpawnProcess below.
using FdToArgsFn = std::function<std::vector<std::string>(int fd)>;
using ConnectInfoToArgsFn = std::function<std::vector<std::string>(const ConnectInfo&)>;

//! Create a socket pair that can be used to communicate within a process or
//! between parent and child processes.
std::array<SocketId, 2> SocketPair();

//! Spawn a new process that communicates with the current process over provided
//! socket argument. Calls connect_info_to_args callback with a connection
//! string that needs to be passed to the child process, and executes the
//! argv command line it returns. Returns child process id.
ProcessId SpawnProcess(SocketId socket, ConnectInfoToArgsFn&& connect_info_to_args);

//! Spawn a new process that communicates with the current process over a socket
//! pair. Returns pid through an output argument, and file descriptor for the
//! local side of the socket.
//! The fd_to_args callback is invoked in the parent process before fork().
//! It must not rely on child pid/state, and must return the command line
//! arguments that should be used to execute the process. Embed the remote file
//! descriptor number in whatever format the child process expects.
int SpawnProcess(int& pid, FdToArgsFn&& fd_to_args);
//! Initialize spawned child process using the ConnectInfo string passed to it,
//! returning a socket id for communicating with the parent process.
SocketId StartSpawned(const ConnectInfo& connect_info);

//! Call execvp with vector args.
//! Not safe to call in a post-fork child of a multi-threaded process.
//! Currently only used by mpgen at build time.
void ExecProcess(const std::vector<std::string>& args);

//! Wait for a process to exit and return its exit code.
int WaitProcess(int pid);
int WaitProcess(ProcessId pid);

inline char* CharCast(char* c) { return c; }
inline char* CharCast(unsigned char* c) { return (char*)c; }
Expand Down
Loading
Loading