Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions lib/uffd/hotpages.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package uffd

import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"sync"
)

// HotPage points at a single page-aligned location inside a registered
// memory region. Region is the index into the handshake's mappings list;
// PageOffset is the byte offset of the page within that region (always
// a multiple of the server's page size).
//
// HotPage contains only integer fields, so values can be compared with
// == and != (Snapshot relies on this to drop duplicates).
type HotPage struct {
// Region is the index of the memory region this page belongs to.
Region uint32
// PageOffset is the byte offset of the page from the start of Region.
PageOffset uint64
}

// HotPageList is the persisted "what pages should we eagerly populate
// before the guest unpauses" list. PR 8 records one of these during a
// template's first fork warm-up and bakes it into Template.HotPagesPath;
// later forks call Server.Prefetch with the loaded list to skip the
// fault round-trips on those pages.
//
// The zero value is an empty list that is ready to use.
//
// Concurrent Add/Snapshot is safe; Save and Load are not — callers
// generally Save once at the end of warmup and Load once at boot.
// (Save stages its output in a fixed path+".tmp" file, so two Saves to
// the same path would write to the same staging file.)
type HotPageList struct {
// mu guards pages; Add, Len and Snapshot hold it while touching pages.
mu sync.Mutex
// pages holds the recorded entries in insertion order, duplicates
// included. Sorting and deduplication happen in Snapshot.
pages []HotPage
}

// hotPagesFileMagic prefixes saved files so we can refuse to load
// arbitrary garbage. The version byte exists so a future format change
// can be rejected loudly instead of silently misinterpreted.
//
// Save writes these four bytes first and LoadHotPageList rejects any file
// that does not start with them. NOTE(review): the "version byte" is
// presumably the trailing '1' of "HPL1" — no separate version field is
// written after the magic; confirm this is the intended scheme.
var hotPagesFileMagic = []byte("HPL1")

// Add appends p to the list while holding the list's mutex. No
// deduplication happens here: repeated pages are stored as-is and are
// collapsed later by Snapshot.
func (h *HotPageList) Add(p HotPage) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.pages = append(h.pages, p)
}

// Len reports how many entries have been recorded so far. Duplicates are
// counted individually, so the value may exceed len(Snapshot()).
func (h *HotPageList) Len() int {
	h.mu.Lock()
	count := len(h.pages)
	h.mu.Unlock()

	return count
}

// Snapshot returns an independent copy of the recorded pages, ordered by
// (Region, PageOffset) with duplicates removed. The ordering lets the
// prefetcher issue sequential reads against the template mem-file.
//
// Only the copy is made under the lock; sorting and deduplication work on
// the private copy, so the receiver is never modified.
func (h *HotPageList) Snapshot() []HotPage {
	h.mu.Lock()
	sorted := append(make([]HotPage, 0, len(h.pages)), h.pages...)
	h.mu.Unlock()

	sort.Slice(sorted, func(a, b int) bool {
		x, y := sorted[a], sorted[b]
		if x.Region == y.Region {
			return x.PageOffset < y.PageOffset
		}
		return x.Region < y.Region
	})

	// Compact in place: after sorting, equal entries are adjacent, so an
	// entry is new exactly when it differs from the last one kept.
	uniq := sorted[:0]
	for _, p := range sorted {
		if n := len(uniq); n > 0 && uniq[n-1] == p {
			continue
		}
		uniq = append(uniq, p)
	}
	return uniq
}

// Save atomically writes the deduplicated snapshot to path.
//
// File format: the 4-byte magic ("HPL1"), a uvarint entry count, then for
// each page a uvarint region index followed by a uvarint page offset.
//
// The data is first written to path+".tmp", synced to stable storage,
// closed, and only then renamed over path, so readers never observe a
// partially written file. On any failure — including a failed rename —
// the staging file is removed and a wrapped error is returned.
func (h *HotPageList) Save(path string) error {
	pages := h.Snapshot()
	tmp := path + ".tmp"
	f, err := os.Create(tmp)
	if err != nil {
		return fmt.Errorf("uffd: create hot pages tmp: %w", err)
	}

	// Single cleanup point: unless the rename succeeded, do not leave the
	// staging file behind. Removal is best-effort; the primary error is
	// the one reported to the caller.
	committed := false
	defer func() {
		if !committed {
			_ = os.Remove(tmp)
		}
	}()

	if err := writeHotPages(f, pages); err != nil {
		_ = f.Close() // best-effort; the write error is what matters
		return err
	}
	// Flush file contents to disk before the rename so a crash cannot
	// leave an empty or truncated file under the final name.
	if err := f.Sync(); err != nil {
		_ = f.Close() // best-effort; the sync error is what matters
		return fmt.Errorf("uffd: sync hot pages tmp: %w", err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("uffd: close hot pages tmp: %w", err)
	}
	if err := os.Rename(tmp, path); err != nil {
		return fmt.Errorf("uffd: rename hot pages: %w", err)
	}
	committed = true
	return nil
}

// writeHotPages encodes pages in the on-disk hot-page format into f
// through a buffered writer and flushes it. It does not sync or close f.
func writeHotPages(f *os.File, pages []HotPage) error {
	bw := bufio.NewWriter(f)
	if _, err := bw.Write(hotPagesFileMagic); err != nil {
		return fmt.Errorf("uffd: write hot pages magic: %w", err)
	}

	var ibuf [binary.MaxVarintLen64]byte
	// putUvarint writes v as a uvarint; what names the field for errors.
	putUvarint := func(v uint64, what string) error {
		n := binary.PutUvarint(ibuf[:], v)
		if _, err := bw.Write(ibuf[:n]); err != nil {
			return fmt.Errorf("uffd: write hot pages %s: %w", what, err)
		}
		return nil
	}

	if err := putUvarint(uint64(len(pages)), "count"); err != nil {
		return err
	}
	for _, p := range pages {
		if err := putUvarint(uint64(p.Region), "region"); err != nil {
			return err
		}
		if err := putUvarint(p.PageOffset, "offset"); err != nil {
			return err
		}
	}
	if err := bw.Flush(); err != nil {
		return fmt.Errorf("uffd: flush hot pages: %w", err)
	}
	return nil
}

// LoadHotPageList reads a HotPageList from path. Returns an empty list
// (not an error) when path does not exist; the absence of a baked
// hot-page file simply means "don't prefetch."
//
// The file must start with the "HPL1" magic, followed by a uvarint entry
// count and exactly that many (uvarint region, uvarint offset) pairs with
// no trailing bytes. Any deviation — short file, wrong magic, malformed
// varint, region index that does not fit in uint32, or leftover bytes —
// yields a nil list and a non-nil error.
func LoadHotPageList(path string) (*HotPageList, error) {
	clean := filepath.Clean(path)
	data, err := os.ReadFile(clean)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return &HotPageList{}, nil
		}
		return nil, fmt.Errorf("uffd: read hot pages: %w", err)
	}
	if len(data) < len(hotPagesFileMagic) {
		return nil, errors.New("uffd: hot pages file truncated")
	}
	if string(data[:len(hotPagesFileMagic)]) != string(hotPagesFileMagic) {
		return nil, errors.New("uffd: hot pages file has bad magic")
	}
	rest := data[len(hotPagesFileMagic):]
	count, n := binary.Uvarint(rest)
	if n <= 0 {
		return nil, errors.New("uffd: hot pages file has bad count")
	}
	rest = rest[n:]

	// The count comes from the file and cannot be trusted: a corrupt
	// header could otherwise request an enormous allocation (or panic in
	// make). Every entry occupies at least two bytes, so the remaining
	// length bounds how many entries can really be present.
	capHint := count
	if maxEntries := uint64(len(rest)) / 2; capHint > maxEntries {
		capHint = maxEntries
	}
	out := &HotPageList{pages: make([]HotPage, 0, capHint)}

	for i := uint64(0); i < count; i++ {
		region, n := binary.Uvarint(rest)
		if n <= 0 {
			return nil, fmt.Errorf("uffd: hot pages file truncated at entry %d (region)", i)
		}
		// Region is stored as uint32 in memory; refuse values that would
		// be silently truncated into a different region index.
		if uint64(uint32(region)) != region {
			return nil, fmt.Errorf("uffd: hot pages file entry %d has region %d out of range", i, region)
		}
		rest = rest[n:]
		offset, n := binary.Uvarint(rest)
		if n <= 0 {
			return nil, fmt.Errorf("uffd: hot pages file truncated at entry %d (offset)", i)
		}
		rest = rest[n:]
		out.pages = append(out.pages, HotPage{Region: uint32(region), PageOffset: offset})
	}
	if len(rest) != 0 {
		return nil, fmt.Errorf("uffd: hot pages file has %d trailing bytes", len(rest))
	}
	return out, nil
}
66 changes: 66 additions & 0 deletions lib/uffd/hotpages_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package uffd

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestHotPageList_SnapshotSortsAndDedups(t *testing.T) {
var l HotPageList
l.Add(HotPage{Region: 1, PageOffset: 8192})
l.Add(HotPage{Region: 0, PageOffset: 4096})
l.Add(HotPage{Region: 0, PageOffset: 4096}) // duplicate
l.Add(HotPage{Region: 0, PageOffset: 0})

got := l.Snapshot()
want := []HotPage{
{Region: 0, PageOffset: 0},
{Region: 0, PageOffset: 4096},
{Region: 1, PageOffset: 8192},
}
assert.Equal(t, want, got)
}

// TestHotPageList_SaveLoadRoundTrip saves a list to a temp file, loads it
// back, and expects both lists to produce the same snapshot.
func TestHotPageList_SaveLoadRoundTrip(t *testing.T) {
	var saved HotPageList
	for _, p := range []HotPage{
		{Region: 0, PageOffset: 0},
		{Region: 0, PageOffset: 4096},
		{Region: 2, PageOffset: 1 << 20},
	} {
		saved.Add(p)
	}

	file := filepath.Join(t.TempDir(), "hot.bin")
	require.NoError(t, saved.Save(file))

	loaded, err := LoadHotPageList(file)
	require.NoError(t, err)
	assert.Equal(t, saved.Snapshot(), loaded.Snapshot())
}

// TestLoadHotPageList_MissingReturnsEmpty checks that a nonexistent file
// is not an error and yields a list with no entries.
func TestLoadHotPageList_MissingReturnsEmpty(t *testing.T) {
	missing := filepath.Join(t.TempDir(), "absent.bin")

	list, err := LoadHotPageList(missing)

	require.NoError(t, err)
	assert.Equal(t, 0, list.Len())
}

// TestLoadHotPageList_BadMagic checks that a file whose first four bytes
// are not the expected magic is rejected. The assertion pins the
// bad-magic error specifically, so an unrelated failure (for example the
// file not being readable) cannot make the test pass by accident.
func TestLoadHotPageList_BadMagic(t *testing.T) {
	path := filepath.Join(t.TempDir(), "bad.bin")
	require.NoError(t, writeFile(path, []byte("XXXX\x00")))

	list, err := LoadHotPageList(path)

	require.Error(t, err)
	assert.Contains(t, err.Error(), "bad magic")
	assert.Nil(t, list)
}

// TestLoadHotPageList_TruncatedAtEntry checks that a file announcing more
// entries than it contains is rejected at the first missing entry.
func TestLoadHotPageList_TruncatedAtEntry(t *testing.T) {
	path := filepath.Join(t.TempDir(), "trunc.bin")
	// Layout: magic, count=2, then a single complete entry (region=0,
	// offset=0). The second announced entry (index 1) is missing.
	data := append([]byte("HPL1"), 0x02, 0x00, 0x00)
	require.NoError(t, writeFile(path, data))

	list, err := LoadHotPageList(path)

	require.Error(t, err)
	assert.Contains(t, err.Error(), "truncated at entry 1")
	assert.Nil(t, list)
}

// writeFile is a test helper that writes data to path, creating or
// truncating the file with owner-only read/write permissions.
func writeFile(path string, data []byte) error {
	const perm os.FileMode = 0o600
	return os.WriteFile(path, data, perm)
}
54 changes: 52 additions & 2 deletions lib/uffd/server_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ func (s *Server) startListener(ctx context.Context, forkID string, socketPath st
}
}

s.installPrefetcher(forkID, func(list *HotPageList) error {
return s.prefetchInto(fd, regions, list)
})

s.servePageFaults(hctx, fd, regions, forkID)
}()

Expand Down Expand Up @@ -272,12 +276,13 @@ func (s *Server) copyPageForFault(fd int, regions []MemoryRegion, addr uint64, p
pageSize := uint64(s.pageSize)
pageStart := addr &^ (pageSize - 1)

for _, r := range regions {
for idx, r := range regions {
base := uint64(r.BaseHostAddr)
if pageStart < base || pageStart >= base+r.Size {
continue
}
offset := int64(r.MemFileOffset + (pageStart - base))
regionOff := pageStart - base
offset := int64(r.MemFileOffset + regionOff)
if _, err := s.memFile.ReadAt(page, offset); err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("uffd: read template at %d: %w", offset, err)
}
Expand All @@ -294,11 +299,56 @@ func (s *Server) copyPageForFault(fd int, regions []MemoryRegion, addr uint64, p
}
return fmt.Errorf("uffd: UFFDIO_COPY: %w", err)
}
if s.cfg.RecordHotPages {
s.hotPages.Add(HotPage{Region: uint32(idx), PageOffset: regionOff})
}
return nil
}
return fmt.Errorf("uffd: fault addr 0x%x outside any registered region", addr)
}

// prefetchInto walks list and issues a UFFDIO_COPY for each entry
// against the supplied fork's userfaultfd. It tolerates EEXIST/EAGAIN
// the same way the fault handler does so a racing first-touch fault
// from a vCPU does not abort the whole prefetch.
//
// fd is the fork's userfaultfd, regions are the fork's registered memory
// regions (HotPage.Region indexes into this slice), and list is the set
// of pages to populate. A nil or empty list is a no-op.
//
// An entry that names an unknown region or does not fit entirely inside
// its region aborts the prefetch with an error, as does a template read
// failure or any UFFDIO_COPY error other than EEXIST/EAGAIN.
func (s *Server) prefetchInto(fd int, regions []MemoryRegion, list *HotPageList) error {
	if list == nil {
		return nil
	}
	pages := list.Snapshot()
	if len(pages) == 0 {
		return nil
	}
	pageSize := uint64(s.pageSize)
	// One scratch buffer is reused for every page.
	page := make([]byte, s.pageSize)
	for _, hp := range pages {
		if uint64(hp.Region) >= uint64(len(regions)) {
			return fmt.Errorf("uffd: prefetch entry refers to region %d (only %d registered)", hp.Region, len(regions))
		}
		r := regions[hp.Region]
		// Overflow-safe form of "PageOffset+pageSize > Size": offsets come
		// from a file on disk, and a huge value must not wrap around and
		// slip past the bounds check.
		if r.Size < pageSize || hp.PageOffset > r.Size-pageSize {
			return fmt.Errorf("uffd: prefetch offset %d outside region %d size %d", hp.PageOffset, hp.Region, r.Size)
		}
		dst := uint64(r.BaseHostAddr) + hp.PageOffset
		src := int64(r.MemFileOffset + hp.PageOffset)
		n, err := s.memFile.ReadAt(page, src)
		if err != nil && !errors.Is(err, io.EOF) {
			return fmt.Errorf("uffd: prefetch read template at %d: %w", src, err)
		}
		// A short read (EOF inside the page) leaves the tail of the
		// reused buffer holding the previous page's bytes; zero it so
		// stale data is never copied into guest memory.
		for i := n; i < len(page); i++ {
			page[i] = 0
		}
		copyArg := uffdioCopyArg{
			Dst: dst,
			Src: uint64(uintptr(unsafe.Pointer(&page[0]))),
			Len: pageSize,
		}
		if err := ioctl(fd, uffdioCopyIoctl, unsafe.Pointer(&copyArg)); err != nil {
			// EEXIST: the page is already populated (e.g. by a racing
			// fault). EAGAIN: skip it and let a later fault fill it in.
			if errors.Is(err, syscall.EEXIST) || errors.Is(err, syscall.EAGAIN) {
				continue
			}
			return fmt.Errorf("uffd: prefetch UFFDIO_COPY: %w", err)
		}
	}
	return nil
}

func ioctl(fd int, req uintptr, arg unsafe.Pointer) error {
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), req, uintptr(arg))
if errno != 0 {
Expand Down
Loading
Loading