From 27082d3ee1775d42951d1e8c9c49392a912b646c Mon Sep 17 00:00:00 2001
From: sjmiller609 <7516283+sjmiller609@users.noreply.github.com>
Date: Fri, 8 May 2026 13:36:09 +0000
Subject: [PATCH] uffd: serve firecracker page faults from a shared template mem-file

Adds lib/uffd, a userfaultfd page server that backs many concurrent
fan-out forks against a single read-only template mem-file instead of
letting each fork mmap it privately. Firecracker connects to a per-fork
UDS, hands us its userfaultfd via SCM_RIGHTS along with a JSON mappings
handshake, and the server then services UFFD_EVENT_PAGEFAULT events with
UFFDIO_COPY reads from the template.

The Linux hot path lives behind a build tag; non-Linux builds return
ErrUnsupported so callers can fall back to MAP_PRIVATE. Cross-platform
tests cover the handshake parser and the server lifecycle.

Co-Authored-By: Claude Opus 4.7
---
 lib/uffd/server_linux.go | 308 +++++++++++++++++++++++++++++++++++++++
 lib/uffd/server_other.go |  15 ++
 lib/uffd/uffd.go         | 253 ++++++++++++++++++++++++++++++++
 lib/uffd/uffd_test.go    | 127 ++++++++++++++++
 4 files changed, 703 insertions(+)
 create mode 100644 lib/uffd/server_linux.go
 create mode 100644 lib/uffd/server_other.go
 create mode 100644 lib/uffd/uffd.go
 create mode 100644 lib/uffd/uffd_test.go

diff --git a/lib/uffd/server_linux.go b/lib/uffd/server_linux.go
new file mode 100644
index 00000000..616ad3ab
--- /dev/null
+++ b/lib/uffd/server_linux.go
@@ -0,0 +1,308 @@
+//go:build linux
+
+package uffd
+
+import (
+	"context"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"sync"
+	"syscall"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+// userfaultfd ioctl numbers and feature flags. The constants are derived
+// from <linux/userfaultfd.h>: _IOWR(0xAA, ...) with the size of each
+// argument struct in bits 16–29.
+const (
+	uffdAPI        = 0xAA
+	uffdAPIFeature = 0x0 // we only need missing-page faults; no extra features.
+
+	uffdioAPI        = 0xC018AA3F // _IOWR(0xAA, 0x3F, struct uffdio_api{24})
+	uffdioRegister   = 0xC020AA00 // _IOWR(0xAA, 0x00, struct uffdio_register{32})
+	uffdioCopyIoctl  = 0xC028AA03 // _IOWR(0xAA, 0x03, struct uffdio_copy{40})
+	uffdioZeropage   = 0xC020AA04 // _IOWR(0xAA, 0x04, struct uffdio_zeropage{32})
+	uffdRegMissing   = 1 << 0
+	uffdEventPagefnt = 0x12 // UFFD_EVENT_PAGEFAULT
+)
+
+// uffdMsg mirrors struct uffd_msg from <linux/userfaultfd.h>. It is a
+// 32-byte fixed-size record; we only consume the pagefault arm.
+type uffdMsg struct {
+	Event     uint8
+	_         uint8
+	_         uint16
+	_         uint32
+	Pagefault struct {
+		Flags   uint64
+		Address uint64
+		Ptid    uint32
+		_       uint32
+	}
+}
+
+// uffdioAPIArg is struct uffdio_api.
+type uffdioAPIArg struct {
+	API      uint64
+	Features uint64
+	Ioctls   uint64
+}
+
+// uffdioRegisterArg is struct uffdio_register.
+type uffdioRegisterArg struct {
+	Start  uint64
+	Len    uint64
+	Mode   uint64
+	Ioctls uint64
+}
+
+// uffdioCopyArg is struct uffdio_copy.
+type uffdioCopyArg struct {
+	Dst  uint64
+	Src  uint64
+	Len  uint64
+	Mode uint64
+	Copy int64
+}
+
+// startListener opens the per-fork UDS, accepts firecracker's connection,
+// receives the userfaultfd via SCM_RIGHTS plus the JSON handshake, and
+// then runs the page-fault loop. The returned closer stops accept,
+// closes the accepted connection and the received userfaultfd, signals
+// the handler, waits for it to exit, and removes the socket file. The
+// closer is idempotent and always returns nil.
+//
+// Errors that occur inside the background goroutine (accept, handshake,
+// ioctl) end the goroutine without being reported to the caller.
+func (s *Server) startListener(ctx context.Context, forkID string, socketPath string) (func() error, error) {
+	// Remove any stale socket file from a prior run; UDS bind fails otherwise.
+	_ = os.Remove(socketPath)
+	ln, err := net.Listen("unix", socketPath)
+	if err != nil {
+		return nil, fmt.Errorf("uffd: listen %s: %w", socketPath, err)
+	}
+
+	hctx, hcancel := context.WithCancel(ctx)
+
+	var (
+		wg     sync.WaitGroup
+		mu     sync.Mutex
+		peer   net.Conn // accepted connection; nil until Accept returns
+		uffdFd int      = -1
+		closed bool
+	)
+
+	closer := func() error {
+		mu.Lock()
+		if closed {
+			mu.Unlock()
+			wg.Wait()
+			return nil
+		}
+		closed = true
+		fd := uffdFd
+		uffdFd = -1
+		conn := peer
+		mu.Unlock()
+
+		hcancel()
+		_ = ln.Close()
+		if conn != nil {
+			// A handshake read may still be parked on the peer; closing
+			// the connection is what wakes it so wg.Wait can return.
+			_ = conn.Close()
+		}
+		if fd >= 0 {
+			_ = unix.Close(fd)
+		}
+		wg.Wait()
+		_ = os.Remove(socketPath)
+		return nil
+	}
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		conn, err := ln.Accept()
+		if err != nil {
+			return
+		}
+		defer conn.Close()
+
+		// Publish the connection before blocking on it so the closer can
+		// interrupt the handshake.
+		mu.Lock()
+		if closed {
+			mu.Unlock()
+			return
+		}
+		peer = conn
+		mu.Unlock()
+
+		fd, regions, err := receiveHandshake(conn)
+		if err != nil {
+			return
+		}
+		mu.Lock()
+		if closed {
+			mu.Unlock()
+			_ = unix.Close(fd)
+			return
+		}
+		uffdFd = fd
+		mu.Unlock()
+
+		if err := uffdAPIHandshake(fd); err != nil {
+			return
+		}
+		for _, r := range regions {
+			if err := uffdRegisterRegion(fd, r); err != nil {
+				return
+			}
+		}
+
+		s.servePageFaults(hctx, fd, regions, forkID)
+	}()
+
+	return closer, nil
+}
+
+// receiveHandshake reads firecracker's JSON payload and the userfaultfd
+// from conn. Firecracker sends them together; if they arrive across
+// several reads we loop until the fd arrives.
+//
+// The read goes through the *net.UnixConn itself (not a duplicated
+// blocking descriptor), so closing conn from another goroutine
+// interrupts it. On success the caller owns the returned fd; on every
+// error path any fd that was already received is closed here.
+func receiveHandshake(conn net.Conn) (int, []MemoryRegion, error) {
+	uc, ok := conn.(*net.UnixConn)
+	if !ok {
+		return -1, nil, errors.New("uffd: connection is not a unix socket")
+	}
+
+	// Read until we have the SCM_RIGHTS fd. The JSON body is small, so
+	// a 4 KiB buffer plus one OOB control message is plenty.
+	buf := make([]byte, 4096)
+	oob := make([]byte, unix.CmsgSpace(4))
+	var (
+		jsonBytes []byte
+		fd        int = -1
+	)
+	for fd < 0 {
+		n, oobn, _, _, err := uc.ReadMsgUnix(buf, oob)
+		if err != nil {
+			if errors.Is(err, io.EOF) {
+				// Peer hung up before passing the descriptor.
+				return -1, nil, io.ErrUnexpectedEOF
+			}
+			return -1, nil, fmt.Errorf("uffd: recvmsg: %w", err)
+		}
+		if n > 0 {
+			jsonBytes = append(jsonBytes, buf[:n]...)
+		}
+		if oobn > 0 {
+			if fd, err = passedFD(oob[:oobn]); err != nil {
+				return -1, nil, err
+			}
+		}
+		if n == 0 && oobn == 0 {
+			return -1, nil, io.ErrUnexpectedEOF
+		}
+	}
+
+	hs, err := parseHandshake(jsonBytes)
+	if err != nil {
+		_ = unix.Close(fd)
+		return -1, nil, err
+	}
+	return fd, hs.Mappings, nil
+}
+
+// passedFD extracts the first descriptor carried by the control messages
+// in oob and closes every additional one so none of them leak. It
+// returns -1 with a nil error when the messages carry no descriptor.
+func passedFD(oob []byte) (int, error) {
+	scms, err := unix.ParseSocketControlMessage(oob)
+	if err != nil {
+		return -1, fmt.Errorf("uffd: parse cmsg: %w", err)
+	}
+	fd := -1
+	for i := range scms {
+		fds, err := unix.ParseUnixRights(&scms[i])
+		if err != nil {
+			if fd >= 0 {
+				_ = unix.Close(fd)
+			}
+			return -1, fmt.Errorf("uffd: parse fds: %w", err)
+		}
+		for _, got := range fds {
+			if fd < 0 {
+				fd = got
+				continue
+			}
+			_ = unix.Close(got)
+		}
+	}
+	return fd, nil
+}
+
+// uffdAPIHandshake issues UFFDIO_API on fd with no optional features.
+//
+// EINVAL is accepted as success.
+// NOTE(review): firecracker's documented snapshot-resume flow creates
+// the userfaultfd and registers guest memory before sending it over the
+// socket, and the kernel rejects a second UFFDIO_API on an
+// already-negotiated descriptor with EINVAL — confirm against the
+// firecracker and kernel versions in use. If the descriptor really is
+// unusable, the UFFDIO_REGISTER calls that follow still fail.
+func uffdAPIHandshake(fd int) error {
+	api := uffdioAPIArg{API: uffdAPI, Features: uffdAPIFeature}
+	err := ioctl(fd, uffdioAPI, unsafe.Pointer(&api))
+	if err == nil || errors.Is(err, syscall.EINVAL) {
+		return nil
+	}
+	return fmt.Errorf("uffd: UFFDIO_API: %w", err)
+}
+
+// uffdRegisterRegion registers [r.BaseHostAddr, r.BaseHostAddr+r.Size)
+// with fd for missing-page faults.
+func uffdRegisterRegion(fd int, r MemoryRegion) error {
+	reg := uffdioRegisterArg{
+		Start: uint64(r.BaseHostAddr),
+		Len:   r.Size,
+		Mode:  uffdRegMissing,
+	}
+	if err := ioctl(fd, uffdioRegister, unsafe.Pointer(&reg)); err != nil {
+		return fmt.Errorf("uffd: UFFDIO_REGISTER: %w", err)
+	}
+	return nil
+}
+
+// servePageFaults blocks reading uffd events on fd. For each
+// UFFD_EVENT_PAGEFAULT we look up the region containing the faulting
+// address, read a page from the template mem-file, and call UFFDIO_COPY
+// to satisfy the fault.
+func (s *Server) servePageFaults(ctx context.Context, fd int, regions []MemoryRegion, forkID string) {
+	// Upper bound, in milliseconds, on how long one poll(2) may sleep
+	// before ctx is looked at again.
+	const pollTimeoutMillis = 100
+
+	// userfaultfd(2): poll reports POLLERR on a userfaultfd unless
+	// O_NONBLOCK is set, and a blocking read cannot observe ctx. Force
+	// non-blocking mode and wait with poll instead.
+	if err := unix.SetNonblock(fd, true); err != nil {
+		return
+	}
+
+	page := make([]byte, s.pageSize)
+	var msg uffdMsg
+	msgSize := int(unsafe.Sizeof(msg))
+	rawBuf := make([]byte, msgSize)
+	pollFds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
+
+	for {
+		if ctx.Err() != nil {
+			return
+		}
+		pollFds[0].Revents = 0
+		ready, err := unix.Poll(pollFds, pollTimeoutMillis)
+		if err != nil {
+			if errors.Is(err, syscall.EINTR) {
+				continue
+			}
+			return
+		}
+		if ready == 0 {
+			// Timed out with nothing pending; loop to re-check ctx.
+			continue
+		}
+		if pollFds[0].Revents&unix.POLLIN == 0 {
+			// POLLERR / POLLHUP / POLLNVAL without data: the descriptor
+			// was closed or is unusable.
+			return
+		}
+		n, err := unix.Read(fd, rawBuf)
+		if err != nil {
+			// EAGAIN: the event was consumed or cancelled between poll
+			// and read; go back to waiting.
+			if errors.Is(err, syscall.EINTR) || errors.Is(err, syscall.EAGAIN) {
+				continue
+			}
+			return
+		}
+		if n != msgSize {
+			return
+		}
+		event := rawBuf[0]
+		if event != uffdEventPagefnt {
+			continue
+		}
+		// pagefault.address starts at offset 16 of uffd_msg.
+		addr := binary.LittleEndian.Uint64(rawBuf[16:24])
+		if err := s.copyPageForFault(fd, regions, addr, page); err != nil {
+			return
+		}
+	}
+}
+
+// copyPageForFault resolves one missing-page fault at addr. It rounds
+// addr down to the server page size, finds the region that contains the
+// page, reads the page from the template mem-file into the scratch
+// buffer page, and installs it with UFFDIO_COPY. It returns an error
+// when the address lies outside every region, the template read fails,
+// or UFFDIO_COPY fails with anything other than EEXIST / EAGAIN.
+func (s *Server) copyPageForFault(fd int, regions []MemoryRegion, addr uint64, page []byte) error {
+	pageSize := uint64(s.pageSize)
+	pageStart := addr &^ (pageSize - 1)
+
+	for _, r := range regions {
+		base := uint64(r.BaseHostAddr)
+		// Compare the offset, not base+Size, so a region that ends at
+		// the top of the address space cannot wrap around.
+		if pageStart < base || pageStart-base >= r.Size {
+			continue
+		}
+		offset := int64(r.MemFileOffset + (pageStart - base))
+		n, err := s.memFile.ReadAt(page, offset)
+		if err != nil && !errors.Is(err, io.EOF) {
+			return fmt.Errorf("uffd: read template at %d: %w", offset, err)
+		}
+		// page is reused across faults: after a short read at the end of
+		// the file the tail still holds bytes from an earlier page, so
+		// zero it before handing the buffer to the kernel.
+		tail := page[n:]
+		for i := range tail {
+			tail[i] = 0
+		}
+		copyArg := uffdioCopyArg{
+			Dst: pageStart,
+			Src: uint64(uintptr(unsafe.Pointer(&page[0]))),
+			Len: pageSize,
+		}
+		if err := ioctl(fd, uffdioCopyIoctl, unsafe.Pointer(&copyArg)); err != nil {
+			// Spurious/duplicate faults can race other vCPUs; treat
+			// them as benign and keep serving.
+			if errors.Is(err, syscall.EEXIST) || errors.Is(err, syscall.EAGAIN) {
+				return nil
+			}
+			return fmt.Errorf("uffd: UFFDIO_COPY: %w", err)
+		}
+		return nil
+	}
+	return fmt.Errorf("uffd: fault addr 0x%x outside any registered region", addr)
+}
+
+// ioctl issues ioctl(2) on fd with request req and argument pointer arg,
+// returning the errno as an error or nil on success.
+func ioctl(fd int, req uintptr, arg unsafe.Pointer) error {
+	_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), req, uintptr(arg))
+	if errno != 0 {
+		return errno
+	}
+	return nil
+}
diff --git a/lib/uffd/server_other.go b/lib/uffd/server_other.go
new file mode 100644
index 00000000..7ce7b287
--- /dev/null
+++ b/lib/uffd/server_other.go
@@ -0,0 +1,15 @@
+//go:build !linux
+
+package uffd
+
+import "context"
+
+// startListener returns ErrUnsupported on non-Linux platforms.
+// userfaultfd is a Linux-only kernel feature; callers should fall back
+// to letting firecracker mmap the mem-file privately.
+func (s *Server) startListener(ctx context.Context, forkID string, socketPath string) (func() error, error) {
+	_ = ctx
+	_ = forkID
+	_ = socketPath
+	return nil, ErrUnsupported
+}
diff --git a/lib/uffd/uffd.go b/lib/uffd/uffd.go
new file mode 100644
index 00000000..9531f02b
--- /dev/null
+++ b/lib/uffd/uffd.go
@@ -0,0 +1,253 @@
+// Package uffd implements a userfaultfd page server for firecracker
+// snapshot fan-out. The server backs many concurrent forks against a
+// single read-only template mem-file: instead of letting firecracker
+// mmap the mem-file privately per fork (which forces every page to be
+// copied on first touch), firecracker is configured to use a
+// userfaultfd memory backend, and this server populates pages on
+// demand from the template file.
+//
+// One Server instance handles one template mem-file and any number of
+// fork connections. Each fork's firecracker process connects to a
+// per-fork UDS and hands the server its userfaultfd via SCM_RIGHTS
+// alongside a JSON payload describing the guest memory mappings; the
+// server then handles UFFDIO_COPY for every faulted page.
+//
+// The protocol (firecrackerHandshake below) is the contract
+// firecracker speaks; we keep it isolated here so PR 8 can ride on
+// top to prefetch hot pages without touching firecracker glue code.
+//
+// PR 5 ships the server skeleton, the protocol parser, and a unit
+// test surface that doesn't require KVM. The hot-path syscalls live
+// in server_linux.go behind a build tag because userfaultfd is a
+// Linux-only kernel feature.
+package uffd
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+	"sync"
+)
+
+// ErrUnsupported is returned on platforms where userfaultfd is not
+// available. Callers should treat this as "fall back to mmap MAP_PRIVATE."
+var ErrUnsupported = errors.New("userfaultfd unsupported on this platform")
+
+// MemoryRegion describes a contiguous region of guest physical memory
+// that maps into a [BaseHostAddr, BaseHostAddr+Size) virtual range in
+// the firecracker process. The server services UFFDIO_COPY into that
+// range using bytes from MemFileOffset.
+type MemoryRegion struct {
+	BaseHostAddr  uintptr `json:"base_host_virt_addr"`
+	Size          uint64  `json:"size"`
+	MemFileOffset uint64  `json:"offset"`
+}
+
+// firecrackerHandshake is the JSON payload firecracker sends on its
+// UDS connection right before it passes the userfaultfd via SCM_RIGHTS.
+// We only use the fields we care about for serving page faults; the
+// rest of firecracker's payload is ignored.
+type firecrackerHandshake struct {
+	Mappings []MemoryRegion `json:"mappings"`
+}
+
+// Config configures a Server.
+type Config struct {
+	// MemFilePath is the path to the template mem-file. The server
+	// opens it read-only and serves pages from it.
+	MemFilePath string
+
+	// SocketDir is where per-fork UDS files live. NewServer creates the
+	// directory if it is missing; it must be writable by the server.
+	// One UDS is created per RegisterFork call.
+	SocketDir string
+
+	// PageSize is the target page size for UFFDIO_COPY. It must be a
+	// power of two (NewServer rejects anything else, because fault
+	// addresses are aligned with a mask) and should be a multiple of
+	// os.Getpagesize. Zero means use the host page size.
+	PageSize int
+}
+
+// Server owns the template mem-file and dispatches userfaultfd events
+// for every connected fork. It is safe for concurrent use; methods may
+// be called from any goroutine.
+type Server struct {
+	cfg     Config
+	memFile *os.File
+	memSize int64
+
+	mu       sync.Mutex
+	listens  map[string]*forkListen // forkID -> per-fork bookkeeping
+	closed   bool
+	pageSize int
+}
+
+// forkListen is the per-fork bookkeeping kept in Server.listens.
+type forkListen struct {
+	socketPath string
+	closer     func() error
+}
+
+// NewServer opens the template mem-file and prepares the server. It
+// does not start any goroutines yet; callers register forks one by one.
+// When the server is closed, the mem-file fd is released; in-flight
+// fork handlers are signaled to exit and joined.
+//
+// It returns an error when MemFilePath or SocketDir is empty, when
+// PageSize is negative or not a power of two, when SocketDir cannot be
+// created, or when the mem-file cannot be opened or stat'ed.
+func NewServer(cfg Config) (*Server, error) {
+	if cfg.MemFilePath == "" {
+		return nil, errors.New("uffd: MemFilePath is required")
+	}
+	if cfg.SocketDir == "" {
+		return nil, errors.New("uffd: SocketDir is required")
+	}
+	pageSize := cfg.PageSize
+	if pageSize == 0 {
+		pageSize = os.Getpagesize()
+	}
+	// The fault handler aligns addresses with addr &^ (pageSize-1) and
+	// allocates a pageSize scratch buffer; both need a positive power
+	// of two.
+	if pageSize < 0 || pageSize&(pageSize-1) != 0 {
+		return nil, fmt.Errorf("uffd: PageSize %d is not a positive power of two", cfg.PageSize)
+	}
+	if err := os.MkdirAll(cfg.SocketDir, 0o755); err != nil {
+		return nil, fmt.Errorf("uffd: ensure socket dir: %w", err)
+	}
+	f, err := os.Open(cfg.MemFilePath)
+	if err != nil {
+		return nil, fmt.Errorf("uffd: open mem-file: %w", err)
+	}
+	st, err := f.Stat()
+	if err != nil {
+		_ = f.Close()
+		return nil, fmt.Errorf("uffd: stat mem-file: %w", err)
+	}
+	return &Server{
+		cfg:      cfg,
+		memFile:  f,
+		memSize:  st.Size(),
+		listens:  map[string]*forkListen{},
+		pageSize: pageSize,
+	}, nil
+}
+
+// SocketPath returns the UDS path that should be passed to firecracker
+// for a fork. RegisterFork must be called first.
+func (s *Server) SocketPath(forkID string) (string, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.closed {
+		return "", errors.New("uffd: server closed")
+	}
+	if entry, registered := s.listens[forkID]; registered {
+		return entry.socketPath, nil
+	}
+	return "", fmt.Errorf("uffd: fork %q is not registered", forkID)
+}
+
+// MemSize reports the template mem-file's size in bytes, as measured
+// when the server was created.
+func (s *Server) MemSize() int64 {
+	return s.memSize
+}
+
+// PageSize reports the page size, in bytes, the server copies per fault.
+func (s *Server) PageSize() int {
+	return s.pageSize
+}
+
+// Close shuts the server down: it marks the server closed, runs the
+// closer of every registered fork, and then closes the template
+// mem-file. The first error encountered along the way is returned.
+// Calling Close again is a no-op that returns nil; a closed server
+// cannot be reused.
+func (s *Server) Close() error {
+	s.mu.Lock()
+	alreadyClosed := s.closed
+	var pending map[string]*forkListen
+	if !alreadyClosed {
+		s.closed = true
+		pending, s.listens = s.listens, nil
+	}
+	s.mu.Unlock()
+	if alreadyClosed {
+		return nil
+	}
+
+	// result keeps only the first failure; later ones are dropped.
+	var result error
+	remember := func(err error) {
+		if err != nil && result == nil {
+			result = err
+		}
+	}
+	for _, entry := range pending {
+		if entry.closer == nil {
+			continue
+		}
+		remember(entry.closer())
+	}
+	remember(s.memFile.Close())
+	return result
+}
+
+// parseHandshake decodes firecracker's JSON handshake payload and
+// rejects payloads that carry no mappings. It is kept free of socket
+// code so tests can exercise the parser directly.
+func parseHandshake(data []byte) (firecrackerHandshake, error) {
+	var payload firecrackerHandshake
+	switch err := json.Unmarshal(data, &payload); {
+	case err != nil:
+		return firecrackerHandshake{}, fmt.Errorf("uffd: parse handshake: %w", err)
+	case len(payload.Mappings) == 0:
+		return firecrackerHandshake{}, errors.New("uffd: handshake has no mappings")
+	}
+	return payload, nil
+}
+
+// resolveSocketPath returns the per-fork socket path. The server uses
+// short names because Unix domain sockets have a tight sun_path limit;
+// callers should keep SocketDir short.
+func (s *Server) resolveSocketPath(forkID string) string {
+	return filepath.Join(s.cfg.SocketDir, forkID+".uffd")
+}
+
+// RegisterFork allocates a per-fork listener and waits asynchronously
+// for firecracker to connect. It returns the UDS path that should be
+// handed to firecracker for this fork. ctx is passed to the listener.
+//
+// The fork ID is reserved under the server lock before the listener is
+// started, so two concurrent calls with the same ID cannot both succeed:
+// the second one gets the "already registered" error. If the server is
+// closed or the fork is unregistered while the listener is starting,
+// the fresh listener is torn down and an error is returned.
+//
+// On Linux the heavy lifting (accept, recvmsg, ioctl loop) lives in
+// server_linux.go; on other platforms RegisterFork returns ErrUnsupported.
+func (s *Server) RegisterFork(ctx context.Context, forkID string) (string, error) {
+	if forkID == "" {
+		return "", errors.New("uffd: fork id is required")
+	}
+	s.mu.Lock()
+	if s.closed {
+		s.mu.Unlock()
+		return "", errors.New("uffd: server closed")
+	}
+	if _, dup := s.listens[forkID]; dup {
+		s.mu.Unlock()
+		return "", fmt.Errorf("uffd: fork %q already registered", forkID)
+	}
+	socketPath := s.resolveSocketPath(forkID)
+	// Reserve the ID in the same critical section as the duplicate
+	// check. The closer is filled in once the listener exists; Close and
+	// UnregisterFork both skip entries whose closer is still nil.
+	entry := &forkListen{socketPath: socketPath}
+	s.listens[forkID] = entry
+	s.mu.Unlock()
+
+	closer, err := s.startListener(ctx, forkID, socketPath)
+
+	s.mu.Lock()
+	// Close sets s.listens to nil; indexing a nil map is safe and simply
+	// reports that the reservation is gone.
+	serverClosed := s.closed
+	owned := !serverClosed && s.listens[forkID] == entry
+	if err == nil && owned {
+		entry.closer = closer
+		s.mu.Unlock()
+		return socketPath, nil
+	}
+	if owned {
+		// The listener failed to start: release the reservation.
+		delete(s.listens, forkID)
+	}
+	s.mu.Unlock()
+
+	if err != nil {
+		return "", err
+	}
+	// The listener started but nobody tracks it any more; stop it here
+	// so its goroutine and socket file do not leak.
+	_ = closer()
+	if serverClosed {
+		return "", errors.New("uffd: server closed during register")
+	}
+	return "", fmt.Errorf("uffd: fork %q unregistered during register", forkID)
+}
+
+// UnregisterFork closes the listener for forkID. Called when the fork
+// is destroyed; the server stops servicing its faults and removes the
+// UDS file.
+func (s *Server) UnregisterFork(forkID string) error {
+	// Detach the entry under the lock, then run its closer without
+	// holding the lock.
+	s.mu.Lock()
+	entry, found := s.listens[forkID]
+	if found {
+		delete(s.listens, forkID)
+	}
+	s.mu.Unlock()
+
+	if !found || entry.closer == nil {
+		return nil
+	}
+	return entry.closer()
+}
diff --git a/lib/uffd/uffd_test.go b/lib/uffd/uffd_test.go
new file mode 100644
index 00000000..c6b2fc3d
--- /dev/null
+++ b/lib/uffd/uffd_test.go
@@ -0,0 +1,127 @@
+package uffd
+
+import (
+	"errors"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// writeTempMemFile creates a zero-filled file of the given size inside a
+// per-test temp directory and returns its path.
+func writeTempMemFile(t *testing.T, size int) string {
+	t.Helper()
+	memPath := filepath.Join(t.TempDir(), "memory")
+	out, err := os.Create(memPath)
+	require.NoError(t, err)
+	require.NoError(t, out.Truncate(int64(size)))
+	require.NoError(t, out.Close())
+	return memPath
+}
+
+// newDefaultServer builds a Server over a 4096-byte temp mem-file with
+// its sockets under socketDir and no explicit page size. The caller is
+// responsible for closing it.
+func newDefaultServer(t *testing.T, socketDir string) *Server {
+	t.Helper()
+	srv, err := NewServer(Config{
+		MemFilePath: writeTempMemFile(t, 4096),
+		SocketDir:   socketDir,
+	})
+	require.NoError(t, err)
+	return srv
+}
+
+func TestNewServer_RequiresMemFile(t *testing.T) {
+	cfg := Config{SocketDir: t.TempDir()}
+	_, err := NewServer(cfg)
+	assert.Error(t, err)
+}
+
+func TestNewServer_RequiresSocketDir(t *testing.T) {
+	cfg := Config{MemFilePath: writeTempMemFile(t, 4096)}
+	_, err := NewServer(cfg)
+	assert.Error(t, err)
+}
+
+func TestNewServer_ReportsMemSizeAndPageSize(t *testing.T) {
+	srv, err := NewServer(Config{
+		MemFilePath: writeTempMemFile(t, 16384),
+		SocketDir:   t.TempDir(),
+		PageSize:    4096,
+	})
+	require.NoError(t, err)
+	defer srv.Close()
+
+	assert.Equal(t, int64(16384), srv.MemSize())
+	assert.Equal(t, 4096, srv.PageSize())
+}
+
+func TestNewServer_DefaultsPageSizeToHost(t *testing.T) {
+	srv := newDefaultServer(t, t.TempDir())
+	defer srv.Close()
+
+	assert.Equal(t, os.Getpagesize(), srv.PageSize())
+}
+
+func TestSocketPath_UnregisteredFork(t *testing.T) {
+	srv := newDefaultServer(t, t.TempDir())
+	defer srv.Close()
+
+	_, err := srv.SocketPath("missing")
+	assert.Error(t, err)
+}
+
+func TestUnregisterFork_UnknownIsNoop(t *testing.T) {
+	srv := newDefaultServer(t, t.TempDir())
+	defer srv.Close()
+
+	assert.NoError(t, srv.UnregisterFork("does-not-exist"))
+}
+
+func TestClose_Idempotent(t *testing.T) {
+	srv := newDefaultServer(t, t.TempDir())
+
+	require.NoError(t, srv.Close())
+	assert.NoError(t, srv.Close())
+}
+
+func TestParseHandshake_GoodPayload(t *testing.T) {
+	const payload = `{"mappings":[{"base_host_virt_addr":4096,"size":8192,"offset":0}]}`
+
+	hs, err := parseHandshake([]byte(payload))
+	require.NoError(t, err)
+	require.Len(t, hs.Mappings, 1)
+
+	region := hs.Mappings[0]
+	assert.Equal(t, uintptr(4096), region.BaseHostAddr)
+	assert.Equal(t, uint64(8192), region.Size)
+	assert.Equal(t, uint64(0), region.MemFileOffset)
+}
+
+func TestParseHandshake_RejectsEmptyMappings(t *testing.T) {
+	_, err := parseHandshake([]byte(`{"mappings":[]}`))
+	assert.Error(t, err)
+}
+
+func TestParseHandshake_RejectsBadJSON(t *testing.T) {
+	_, err := parseHandshake([]byte(`{not json`))
+	assert.Error(t, err)
+}
+
+func TestResolveSocketPath_PerFork(t *testing.T) {
+	socketDir := t.TempDir()
+	srv := newDefaultServer(t, socketDir)
+	defer srv.Close()
+
+	want := filepath.Join(socketDir, "fork-1.uffd")
+	assert.Equal(t, want, srv.resolveSocketPath("fork-1"))
+}
+
+func TestErrUnsupportedSentinel(t *testing.T) {
+	// The sentinel must be a stable error value so callers can switch on it.
+	matches := errors.Is(ErrUnsupported, ErrUnsupported)
+	assert.True(t, matches)
+}