diff --git a/lib/uffd/hotpages.go b/lib/uffd/hotpages.go new file mode 100644 index 00000000..25eecad9 --- /dev/null +++ b/lib/uffd/hotpages.go @@ -0,0 +1,176 @@ +package uffd + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "sync" +) + +// HotPage points at a single page-aligned location inside a registered +// memory region. Region is the index into the handshake's mappings list; +// PageOffset is the byte offset of the page within that region (always +// a multiple of the server's page size). +type HotPage struct { + Region uint32 + PageOffset uint64 +} + +// HotPageList is the persisted "what pages should we eagerly populate +// before the guest unpauses" list. PR 8 records one of these during a +// template's first fork warm-up and bakes it into Template.HotPagesPath; +// later forks call Server.Prefetch with the loaded list to skip the +// fault round-trips on those pages. +// +// Concurrent Add/Snapshot is safe; Save and Load are not — callers +// generally Save once at the end of warmup and Load once at boot. +type HotPageList struct { + mu sync.Mutex + pages []HotPage +} + +// hotPagesFileMagic prefixes saved files so we can refuse to load +// arbitrary garbage. The version byte exists so a future format change +// can be rejected loudly instead of silently misinterpreted. +var hotPagesFileMagic = []byte("HPL1") + +// Add records a single hot page. Duplicates are tolerated; Snapshot +// dedups before returning. +func (h *HotPageList) Add(p HotPage) { + h.mu.Lock() + h.pages = append(h.pages, p) + h.mu.Unlock() +} + +// Len returns the number of recorded pages (with duplicates). +func (h *HotPageList) Len() int { + h.mu.Lock() + defer h.mu.Unlock() + return len(h.pages) +} + +// Snapshot returns a sorted, deduplicated copy of the recorded pages. +// Sort order is (Region, PageOffset) so prefetch issues sequential +// reads against the template mem-file. +func (h *HotPageList) Snapshot() []HotPage { + h.mu.Lock() + src := make([]HotPage, len(h.pages)) + copy(src, h.pages) + h.mu.Unlock() + + sort.Slice(src, func(i, j int) bool { + if src[i].Region != src[j].Region { + return src[i].Region < src[j].Region + } + return src[i].PageOffset < src[j].PageOffset + }) + out := src[:0] + var last HotPage + for i, p := range src { + if i == 0 || p != last { + out = append(out, p) + last = p + } + } + return out +} + +// Save atomically writes the deduplicated snapshot to path. The format +// is: 4-byte magic ("HPL1"), uvarint count, then for each page a +// uvarint region index and a uvarint page offset. Atomic via tmp+rename. +func (h *HotPageList) Save(path string) error { + pages := h.Snapshot() + tmp := path + ".tmp" + f, err := os.Create(tmp) + if err != nil { + return fmt.Errorf("uffd: create hot pages tmp: %w", err) + } + bw := bufio.NewWriter(f) + if _, err := bw.Write(hotPagesFileMagic); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("uffd: write hot pages magic: %w", err) + } + var ibuf [binary.MaxVarintLen64]byte + n := binary.PutUvarint(ibuf[:], uint64(len(pages))) + if _, err := bw.Write(ibuf[:n]); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("uffd: write hot pages count: %w", err) + } + for _, p := range pages { + n = binary.PutUvarint(ibuf[:], uint64(p.Region)) + if _, err := bw.Write(ibuf[:n]); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("uffd: write hot pages region: %w", err) + } + n = binary.PutUvarint(ibuf[:], p.PageOffset) + if _, err := bw.Write(ibuf[:n]); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("uffd: write hot pages offset: %w", err) + } + } + if err := bw.Flush(); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("uffd: flush hot pages: %w", err) + } + if err := f.Close(); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("uffd: close hot pages tmp: %w", err) + } + if err := os.Rename(tmp, path); err != nil { + return fmt.Errorf("uffd: rename hot pages: %w", err) + } + return nil +} + +// LoadHotPageList reads a HotPageList from path. Returns an empty list +// (not an error) when path does not exist; the absence of a baked +// hot-page file simply means "don't prefetch." +func LoadHotPageList(path string) (*HotPageList, error) { + clean := filepath.Clean(path) + data, err := os.ReadFile(clean) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return &HotPageList{}, nil + } + return nil, fmt.Errorf("uffd: read hot pages: %w", err) + } + if len(data) < len(hotPagesFileMagic) { + return nil, errors.New("uffd: hot pages file truncated") + } + if string(data[:len(hotPagesFileMagic)]) != string(hotPagesFileMagic) { + return nil, errors.New("uffd: hot pages file has bad magic") + } + rest := data[len(hotPagesFileMagic):] + count, n := binary.Uvarint(rest) + if n <= 0 { + return nil, errors.New("uffd: hot pages file has bad count") + } + rest = rest[n:] + out := &HotPageList{pages: make([]HotPage, 0, count)} + for i := uint64(0); i < count; i++ { + region, n := binary.Uvarint(rest) + if n <= 0 { + return nil, fmt.Errorf("uffd: hot pages file truncated at entry %d (region)", i) + } + rest = rest[n:] + offset, n := binary.Uvarint(rest) + if n <= 0 { + return nil, fmt.Errorf("uffd: hot pages file truncated at entry %d (offset)", i) + } + rest = rest[n:] + out.pages = append(out.pages, HotPage{Region: uint32(region), PageOffset: offset}) + } + if len(rest) != 0 { + return nil, fmt.Errorf("uffd: hot pages file has %d trailing bytes", len(rest)) + } + return out, nil +} diff --git a/lib/uffd/hotpages_test.go b/lib/uffd/hotpages_test.go new file mode 100644 index 00000000..4441194f --- /dev/null +++ b/lib/uffd/hotpages_test.go @@ -0,0 +1,66 @@ +package uffd + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHotPageList_SnapshotSortsAndDedups(t *testing.T) { + var l HotPageList + l.Add(HotPage{Region: 1, PageOffset: 8192}) + l.Add(HotPage{Region: 0, PageOffset: 4096}) + l.Add(HotPage{Region: 0, PageOffset: 4096}) // duplicate + l.Add(HotPage{Region: 0, PageOffset: 0}) + + got := l.Snapshot() + want := []HotPage{ + {Region: 0, PageOffset: 0}, + {Region: 0, PageOffset: 4096}, + {Region: 1, PageOffset: 8192}, + } + assert.Equal(t, want, got) +} + +func TestHotPageList_SaveLoadRoundTrip(t *testing.T) { + var l HotPageList + l.Add(HotPage{Region: 0, PageOffset: 0}) + l.Add(HotPage{Region: 0, PageOffset: 4096}) + l.Add(HotPage{Region: 2, PageOffset: 1 << 20}) + + path := filepath.Join(t.TempDir(), "hot.bin") + require.NoError(t, l.Save(path)) + + got, err := LoadHotPageList(path) + require.NoError(t, err) + assert.Equal(t, l.Snapshot(), got.Snapshot()) +} + +func TestLoadHotPageList_MissingReturnsEmpty(t *testing.T) { + got, err := LoadHotPageList(filepath.Join(t.TempDir(), "absent.bin")) + require.NoError(t, err) + assert.Equal(t, 0, got.Len()) +} + +func TestLoadHotPageList_BadMagic(t *testing.T) { + path := filepath.Join(t.TempDir(), "bad.bin") + require.NoError(t, writeFile(path, []byte("XXXX\x00"))) + _, err := LoadHotPageList(path) + assert.Error(t, err) +} + +func TestLoadHotPageList_TruncatedAtEntry(t *testing.T) { + path := filepath.Join(t.TempDir(), "trunc.bin") + // magic + count=2 + only one entry + data := append([]byte("HPL1"), 0x02, 0x00, 0x00) // count=2, region=0, offset=0 + require.NoError(t, writeFile(path, data)) + _, err := LoadHotPageList(path) + assert.Error(t, err) +} + +func writeFile(path string, data []byte) error { + return os.WriteFile(path, data, 0o600) +} diff --git a/lib/uffd/server_linux.go b/lib/uffd/server_linux.go index 616ad3ab..618addb4 100644 --- a/lib/uffd/server_linux.go +++ b/lib/uffd/server_linux.go @@ -145,6 +145,10 @@ func (s *Server) startListener(ctx context.Context, forkID string, socketPath st } } + s.installPrefetcher(forkID, func(list *HotPageList) error { + return s.prefetchInto(fd, regions, list) + }) + s.servePageFaults(hctx, fd, regions, forkID) }() @@ -272,12 +276,13 @@ func (s *Server) copyPageForFault(fd int, regions []MemoryRegion, addr uint64, p pageSize := uint64(s.pageSize) pageStart := addr &^ (pageSize - 1) - for _, r := range regions { + for idx, r := range regions { base := uint64(r.BaseHostAddr) if pageStart < base || pageStart >= base+r.Size { continue } - offset := int64(r.MemFileOffset + (pageStart - base)) + regionOff := pageStart - base + offset := int64(r.MemFileOffset + regionOff) if _, err := s.memFile.ReadAt(page, offset); err != nil && !errors.Is(err, io.EOF) { return fmt.Errorf("uffd: read template at %d: %w", offset, err) } @@ -294,11 +299,56 @@ func (s *Server) copyPageForFault(fd int, regions []MemoryRegion, addr uint64, p } return fmt.Errorf("uffd: UFFDIO_COPY: %w", err) } + if s.cfg.RecordHotPages { + s.hotPages.Add(HotPage{Region: uint32(idx), PageOffset: regionOff}) + } return nil } return fmt.Errorf("uffd: fault addr 0x%x outside any registered region", addr) } +// prefetchInto walks list and issues a UFFDIO_COPY for each entry +// against the supplied fork's userfaultfd. It tolerates EEXIST/EAGAIN +// the same way the fault handler does so a racing first-touch fault +// from a vCPU does not abort the whole prefetch. +func (s *Server) prefetchInto(fd int, regions []MemoryRegion, list *HotPageList) error { + if list == nil { + return nil + } + pages := list.Snapshot() + if len(pages) == 0 { + return nil + } + page := make([]byte, s.pageSize) + pageSize := uint64(s.pageSize) + for _, hp := range pages { + if int(hp.Region) >= len(regions) { + return fmt.Errorf("uffd: prefetch entry refers to region %d (only %d registered)", hp.Region, len(regions)) + } + r := regions[hp.Region] + if hp.PageOffset+pageSize > r.Size { + return fmt.Errorf("uffd: prefetch offset %d outside region %d size %d", hp.PageOffset, hp.Region, r.Size) + } + dst := uint64(r.BaseHostAddr) + hp.PageOffset + src := int64(r.MemFileOffset + hp.PageOffset) + if _, err := s.memFile.ReadAt(page, src); err != nil && !errors.Is(err, io.EOF) { + return fmt.Errorf("uffd: prefetch read template at %d: %w", src, err) + } + copyArg := uffdioCopyArg{ + Dst: dst, + Src: uint64(uintptr(unsafe.Pointer(&page[0]))), + Len: pageSize, + } + if err := ioctl(fd, uffdioCopyIoctl, unsafe.Pointer(©Arg)); err != nil { + if errors.Is(err, syscall.EEXIST) || errors.Is(err, syscall.EAGAIN) { + continue + } + return fmt.Errorf("uffd: prefetch UFFDIO_COPY: %w", err) + } + } + return nil +} + func ioctl(fd int, req uintptr, arg unsafe.Pointer) error { _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), req, uintptr(arg)) if errno != 0 { diff --git a/lib/uffd/uffd.go b/lib/uffd/uffd.go index 9531f02b..4a4cabd4 100644 --- a/lib/uffd/uffd.go +++ b/lib/uffd/uffd.go @@ -68,6 +68,12 @@ type Config struct { // PageSize is the target page size for UFFDIO_COPY. Must be a // multiple of os.Getpagesize. Zero means use the host page size. PageSize int + + // RecordHotPages turns on per-fault recording. Every successfully + // served page is appended to the server's hot-page list. Callers + // typically enable this during a template's first warmup fork, + // then HotPages().Save() the result before promoting the template. + RecordHotPages bool } // Server owns the template mem-file and dispatches userfaultfd events @@ -82,11 +88,19 @@ type Server struct { listens map[string]*forkListen // forkID -> per-fork bookkeeping closed bool pageSize int + + hotPages HotPageList // recorded faults; only used when cfg.RecordHotPages } type forkListen struct { socketPath string closer func() error + + // prefetch is set by the platform-specific listener once the uffd + // fd has been received and registered. Calling it issues UFFDIO_COPY + // for every entry in the supplied list against the fork's uffd. + // Nil means the fork hasn't connected yet. + prefetch func(*HotPageList) error } // NewServer opens the template mem-file and prepares the server. It @@ -147,6 +161,50 @@ func (s *Server) MemSize() int64 { return s.memSize } // PageSize returns the configured page size in bytes. func (s *Server) PageSize() int { return s.pageSize } +// HotPages returns the server's hot-page recorder. The returned value +// is the live list — Add/Snapshot/Save are all valid. Recording only +// happens when Config.RecordHotPages is set; callers may still inspect +// the (empty) list otherwise. +func (s *Server) HotPages() *HotPageList { return &s.hotPages } + +// Prefetch issues UFFDIO_COPY for every entry in list against the fork +// identified by forkID. Used to warm up known-hot pages before the +// guest unpauses, eliminating fault round-trips for anything we've +// pre-recorded. Returns an error if the fork is unknown or hasn't +// connected yet, or if the underlying ioctl fails (other than the +// benign EEXIST/EAGAIN race noted in copyPageForFault). +func (s *Server) Prefetch(forkID string, list *HotPageList) error { + if list == nil || list.Len() == 0 { + return nil + } + s.mu.Lock() + listen, ok := s.listens[forkID] + prefetch := func(*HotPageList) error { return nil } + if ok && listen.prefetch != nil { + prefetch = listen.prefetch + } + s.mu.Unlock() + if !ok { + return fmt.Errorf("uffd: fork %q is not registered", forkID) + } + if listen.prefetch == nil { + return fmt.Errorf("uffd: fork %q has not yet connected", forkID) + } + return prefetch(list) +} + +// installPrefetcher is called by the platform-specific listener once +// the uffd is ready. It is a no-op if the fork has been unregistered. +func (s *Server) installPrefetcher(forkID string, fn func(*HotPageList) error) { + s.mu.Lock() + defer s.mu.Unlock() + listen, ok := s.listens[forkID] + if !ok { + return + } + listen.prefetch = fn +} + // Close stops the server, closes all per-fork listeners, and releases // the template mem-file fd. After Close returns, the server cannot be // reused.