diff --git a/lib/hypervisor/cloudhypervisor/cloudhypervisor.go b/lib/hypervisor/cloudhypervisor/cloudhypervisor.go index 95e0792a..81a61032 100644 --- a/lib/hypervisor/cloudhypervisor/cloudhypervisor.go +++ b/lib/hypervisor/cloudhypervisor/cloudhypervisor.go @@ -15,6 +15,7 @@ import ( type CloudHypervisor struct { client *vmm.VMM socketPath string + serialLog *serialSocketLogger } var balloonTargetCache hypervisor.BalloonTargetCache @@ -72,6 +73,8 @@ func (c *CloudHypervisor) DeleteVM(ctx context.Context) error { // Shutdown stops the VMM process gracefully. func (c *CloudHypervisor) Shutdown(ctx context.Context) error { + defer c.serialLog.Close() + resp, err := c.client.ShutdownVMMWithResponse(ctx) if err != nil { return fmt.Errorf("shutdown vmm: %w", err) diff --git a/lib/hypervisor/cloudhypervisor/config.go b/lib/hypervisor/cloudhypervisor/config.go index d728036b..ba129887 100644 --- a/lib/hypervisor/cloudhypervisor/config.go +++ b/lib/hypervisor/cloudhypervisor/config.go @@ -1,10 +1,14 @@ package cloudhypervisor import ( + "path/filepath" + "github.com/kernel/hypeman/lib/hypervisor" "github.com/kernel/hypeman/lib/vmm" ) +const cloudHypervisorSerialSocketName = "serial.sock" + // ToVMConfig converts hypervisor.VMConfig to Cloud Hypervisor's vmm.VmConfig. func ToVMConfig(cfg hypervisor.VMConfig) vmm.VmConfig { // Payload configuration (kernel + initramfs) @@ -66,10 +70,16 @@ func ToVMConfig(cfg hypervisor.VMConfig) vmm.VmConfig { disks = append(disks, disk) } - // Serial console configuration + // Serial console configuration. Cloud Hypervisor opens File mode without + // O_APPEND, so use Socket mode and let hypeman own the append-mode log fd. serial := vmm.ConsoleConfig{ - Mode: vmm.ConsoleConfigMode("File"), - File: ptr(cfg.SerialLogPath), + Mode: vmm.ConsoleConfigModeNull, + } + if cfg.SerialLogPath != "" { + serial = vmm.ConsoleConfig{ + Mode: vmm.ConsoleConfigModeSocket, + Socket: ptr(serialSocketPathForLog(cfg.SerialLogPath)), + } } // Console off (we use serial) @@ -139,3 +149,15 @@ func ToVMConfig(cfg hypervisor.VMConfig) vmm.VmConfig { Balloon: balloon, } } + +func serialSocketPathForLog(logPath string) string { + dir := filepath.Dir(logPath) + if filepath.Base(dir) == "logs" { + dir = filepath.Dir(dir) + } + return filepath.Join(dir, cloudHypervisorSerialSocketName) +} + +func appLogPathForSerialSocket(socketPath string) string { + return filepath.Join(filepath.Dir(socketPath), "logs", "app.log") +} diff --git a/lib/hypervisor/cloudhypervisor/config_test.go b/lib/hypervisor/cloudhypervisor/config_test.go index b5cdb96e..f149c68f 100644 --- a/lib/hypervisor/cloudhypervisor/config_test.go +++ b/lib/hypervisor/cloudhypervisor/config_test.go @@ -27,3 +27,31 @@ func TestToVMConfig_GuestMemoryBalloon(t *testing.T) { require.NotNil(t, vmCfg.Balloon.FreePageReporting) assert.True(t, *vmCfg.Balloon.FreePageReporting) } + +func TestToVMConfig_SerialUsesSocket(t *testing.T) { + cfg := hypervisor.VMConfig{ + VCPUs: 1, + MemoryBytes: 512 * 1024 * 1024, + SerialLogPath: "/var/lib/hypeman/guests/test/logs/app.log", + } + + vmCfg := ToVMConfig(cfg) + require.NotNil(t, vmCfg.Serial) + assert.Equal(t, "Socket", string(vmCfg.Serial.Mode)) + require.NotNil(t, vmCfg.Serial.Socket) + assert.Equal(t, "/var/lib/hypeman/guests/test/serial.sock", *vmCfg.Serial.Socket) + assert.Nil(t, vmCfg.Serial.File) +} + +func TestToVMConfig_SerialNullWhenNoLogPath(t *testing.T) { + cfg := hypervisor.VMConfig{ + VCPUs: 1, + MemoryBytes: 512 * 1024 * 1024, + } + + vmCfg := ToVMConfig(cfg) + require.NotNil(t, vmCfg.Serial) + assert.Equal(t, "Null", string(vmCfg.Serial.Mode)) + assert.Nil(t, vmCfg.Serial.Socket) + assert.Nil(t, vmCfg.Serial.File) +} diff --git a/lib/hypervisor/cloudhypervisor/fork_snapshot.go b/lib/hypervisor/cloudhypervisor/fork_snapshot.go index 831c600e..a8f6077f 100644 --- a/lib/hypervisor/cloudhypervisor/fork_snapshot.go +++ b/lib/hypervisor/cloudhypervisor/fork_snapshot.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/kernel/hypeman/lib/hypervisor" + "github.com/kernel/hypeman/lib/vmm" ) // rewriteSnapshotConfigForFork rewrites Cloud Hypervisor snapshot config.json for a forked instance. @@ -88,7 +89,9 @@ func updateSerialConfig(config map[string]any, logPath string) { if !ok || serial == nil { return } - serial["file"] = logPath + delete(serial, "file") + serial["mode"] = string(vmm.ConsoleConfigModeSocket) + serial["socket"] = serialSocketPathForLog(logPath) } func updateNetworkConfig(config map[string]any, netCfg *hypervisor.ForkNetworkConfig) { diff --git a/lib/hypervisor/cloudhypervisor/fork_snapshot_test.go b/lib/hypervisor/cloudhypervisor/fork_snapshot_test.go index b25b1034..2ccfc87c 100644 --- a/lib/hypervisor/cloudhypervisor/fork_snapshot_test.go +++ b/lib/hypervisor/cloudhypervisor/fork_snapshot_test.go @@ -17,7 +17,7 @@ func TestRewriteSnapshotConfigForFork(t *testing.T) { orig := map[string]any{ "disks": []any{map[string]any{"path": "/src/guests/a/overlay.raw"}}, - "serial": map[string]any{"file": "/src/guests/a/logs/app.log"}, + "serial": map[string]any{"mode": "File", "file": "/src/guests/a/logs/app.log"}, "vsock": map[string]any{"cid": float64(100), "socket": "/src/guests/a/vsock.sock"}, "metadata": map[string]any{ "note": "keep-/src/guests/a-as-substring", @@ -59,7 +59,9 @@ func TestRewriteSnapshotConfigForFork(t *testing.T) { assert.Equal(t, "/dst/guests/b/overlay.raw", disk0["path"]) serial := updated["serial"].(map[string]any) - assert.Equal(t, "/dst/guests/b/logs/app.log", serial["file"]) + assert.Equal(t, "Socket", serial["mode"]) + assert.Equal(t, "/dst/guests/b/serial.sock", serial["socket"]) + assert.NotContains(t, serial, "file") vsock := updated["vsock"].(map[string]any) assert.Equal(t, float64(100), vsock["cid"]) diff --git a/lib/hypervisor/cloudhypervisor/process.go b/lib/hypervisor/cloudhypervisor/process.go index 0618de43..0d8ee77b 100644 --- a/lib/hypervisor/cloudhypervisor/process.go +++ b/lib/hypervisor/cloudhypervisor/process.go @@ -2,6 +2,7 @@ package cloudhypervisor import ( "context" + "encoding/json" "errors" "fmt" "log/slog" @@ -91,6 +92,12 @@ func (s *Starter) StartVM(ctx context.Context, p *paths.Paths, version string, s // 3. Configure the VM via HTTP API vmConfig := ToVMConfig(config) + serialSocketPath, serialLogPath := serialLogPathsFromVMConfig(vmConfig) + if serialSocketPath != "" { + if err := removeStaleSerialSocket(serialSocketPath); err != nil { + return 0, nil, err + } + } resp, err := hv.client.CreateVMWithResponse(ctx, vmConfig) if err != nil { logStartVMFailureDiagnostics(ctx, log, socketPath, pid, "create_vm", err, 0, "") @@ -100,14 +107,23 @@ func (s *Starter) StartVM(ctx context.Context, p *paths.Paths, version string, s logStartVMFailureDiagnostics(ctx, log, socketPath, pid, "create_vm", nil, resp.StatusCode(), string(resp.Body)) return 0, nil, fmt.Errorf("create vm failed with status %d: %s", resp.StatusCode(), string(resp.Body)) } + if serialSocketPath != "" { + serialLog, err := startSerialSocketLogger(ctx, serialSocketPath, serialLogPath) + if err != nil { + return 0, nil, fmt.Errorf("start serial logger: %w", err) + } + hv.serialLog = serialLog + } // 4. Boot the VM via HTTP API bootResp, err := hv.client.BootVMWithResponse(ctx) if err != nil { + hv.serialLog.Close() logStartVMFailureDiagnostics(ctx, log, socketPath, pid, "boot_vm", err, 0, "") return 0, nil, fmt.Errorf("boot vm: %w", err) } if bootResp.StatusCode() != 204 { + hv.serialLog.Close() logStartVMFailureDiagnostics(ctx, log, socketPath, pid, "boot_vm", nil, bootResp.StatusCode(), string(bootResp.Body)) return 0, nil, fmt.Errorf("boot vm failed with status %d: %s", bootResp.StatusCode(), string(bootResp.Body)) } @@ -151,6 +167,13 @@ func (s *Starter) RestoreVM(ctx context.Context, p *paths.Paths, version string, return 0, nil, fmt.Errorf("create client: %w", err) } + serialSocketPath, serialLogPath := serialLogPathsFromSnapshot(snapshotPath) + if serialSocketPath != "" { + if err := removeStaleSerialSocket(serialSocketPath); err != nil { + return 0, nil, err + } + } + // 3. Restore from snapshot via HTTP API restoreAPIStart := time.Now() sourceURL := "file://" + snapshotPath @@ -166,6 +189,13 @@ func (s *Starter) RestoreVM(ctx context.Context, p *paths.Paths, version string, return 0, nil, fmt.Errorf("restore failed with status %d: %s", resp.StatusCode(), string(resp.Body)) } log.DebugContext(ctx, "CH restore API complete", "duration_ms", time.Since(restoreAPIStart).Milliseconds()) + if serialSocketPath != "" { + serialLog, err := startSerialSocketLogger(ctx, serialSocketPath, serialLogPath) + if err != nil { + return 0, nil, fmt.Errorf("start serial logger: %w", err) + } + hv.serialLog = serialLog + } // Success - release cleanup to prevent killing the process cu.Release() @@ -177,6 +207,42 @@ func ptr[T any](v T) *T { return &v } +func serialLogPathsFromVMConfig(config vmm.VmConfig) (socketPath, logPath string) { + if config.Serial == nil || config.Serial.Socket == nil || config.Serial.Mode != vmm.ConsoleConfigModeSocket { + return "", "" + } + return *config.Serial.Socket, appLogPathForSerialSocket(*config.Serial.Socket) +} + +func serialLogPathsFromSnapshot(snapshotPath string) (socketPath, logPath string) { + data, err := os.ReadFile(filepath.Join(snapshotPath, "config.json")) + if err != nil { + return "", "" + } + + var config map[string]any + if err := json.Unmarshal(data, &config); err != nil { + return "", "" + } + + serial, ok := config["serial"].(map[string]any) + if !ok || serial == nil || serial["mode"] != string(vmm.ConsoleConfigModeSocket) { + return "", "" + } + socketPath, _ = serial["socket"].(string) + if socketPath == "" { + return "", "" + } + return socketPath, appLogPathForSerialSocket(socketPath) +} + +func removeStaleSerialSocket(socketPath string) error { + if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("remove stale serial socket: %w", err) + } + return nil +} + func logStartVMFailureDiagnostics(ctx context.Context, log *slog.Logger, socketPath string, pid int, operation string, requestErr error, statusCode int, responseBody string) { if log == nil { return diff --git a/lib/hypervisor/cloudhypervisor/serial_log.go b/lib/hypervisor/cloudhypervisor/serial_log.go new file mode 100644 index 00000000..23255948 --- /dev/null +++ b/lib/hypervisor/cloudhypervisor/serial_log.go @@ -0,0 +1,89 @@ +package cloudhypervisor + +import ( + "context" + "fmt" + "io" + "net" + "os" + "path/filepath" + "sync" + "time" +) + +const serialSocketConnectTimeout = 5 * time.Second + +type serialSocketLogger struct { + conn net.Conn + file *os.File + done chan struct{} + once sync.Once +} + +func startSerialSocketLogger(ctx context.Context, socketPath, logPath string) (*serialSocketLogger, error) { + if err := os.MkdirAll(filepath.Dir(logPath), 0755); err != nil { + return nil, fmt.Errorf("create serial log directory: %w", err) + } + + logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil, fmt.Errorf("open serial log: %w", err) + } + + conn, err := dialSerialSocket(ctx, socketPath) + if err != nil { + logFile.Close() + return nil, err + } + + logger := &serialSocketLogger{ + conn: conn, + file: logFile, + done: make(chan struct{}), + } + go logger.copy() + return logger, nil +} + +func dialSerialSocket(ctx context.Context, socketPath string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, serialSocketConnectTimeout) + defer cancel() + + dialer := net.Dialer{} + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + var lastErr error + for { + conn, err := dialer.DialContext(ctx, "unix", socketPath) + if err == nil { + return conn, nil + } + lastErr = err + + select { + case <-ctx.Done(): + return nil, fmt.Errorf("connect serial socket %s: %w", socketPath, lastErr) + case <-ticker.C: + } + } +} + +func (l *serialSocketLogger) copy() { + defer close(l.done) + _, _ = io.Copy(l.file, l.conn) +} + +// Close terminates the serial logger. It closes the connection (unblocking +// io.Copy), waits for the copy goroutine to finish, then closes the log file. +// Safe to call on a nil receiver and idempotent. +func (l *serialSocketLogger) Close() { + if l == nil { + return + } + l.once.Do(func() { + _ = l.conn.Close() + <-l.done + _ = l.file.Close() + }) +} diff --git a/lib/hypervisor/cloudhypervisor/serial_log_test.go b/lib/hypervisor/cloudhypervisor/serial_log_test.go new file mode 100644 index 00000000..cc1ddee7 --- /dev/null +++ b/lib/hypervisor/cloudhypervisor/serial_log_test.go @@ -0,0 +1,167 @@ +package cloudhypervisor + +import ( + "context" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSerialSocketLoggerWritesWithAppendAfterTruncate(t *testing.T) { + tmp := socketTempDir(t) + socketPath := filepath.Join(tmp, "serial.sock") + logPath := filepath.Join(tmp, "logs", "app.log") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + defer listener.Close() + + logger, err := startSerialSocketLogger(t.Context(), socketPath, logPath) + require.NoError(t, err) + defer logger.Close() + + conn, err := listener.Accept() + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte("first\n")) + require.NoError(t, err) + requireEventuallyFileContent(t, logPath, "first\n") + + require.NoError(t, os.Truncate(logPath, 0)) + _, err = conn.Write([]byte("second\n")) + require.NoError(t, err) + requireEventuallyFileContent(t, logPath, "second\n") +} + +func TestDialSerialSocketRetriesUntilAvailable(t *testing.T) { + tmp := socketTempDir(t) + socketPath := filepath.Join(tmp, "serial.sock") + + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + defer cancel() + + connected := make(chan net.Conn, 1) + errCh := make(chan error, 1) + go func() { + conn, err := dialSerialSocket(ctx, socketPath) + if err != nil { + errCh <- err + return + } + connected <- conn + }() + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + defer listener.Close() + + serverConn, err := listener.Accept() + require.NoError(t, err) + defer serverConn.Close() + + select { + case conn := <-connected: + defer conn.Close() + case err := <-errCh: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timed out waiting for serial socket dial") + } +} + +func TestSerialSocketLoggerCloseIsIdempotent(t *testing.T) { + tmp := socketTempDir(t) + socketPath := filepath.Join(tmp, "serial.sock") + logPath := filepath.Join(tmp, "logs", "app.log") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + defer listener.Close() + + logger, err := startSerialSocketLogger(t.Context(), socketPath, logPath) + require.NoError(t, err) + + conn, err := listener.Accept() + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte("before-close\n")) + require.NoError(t, err) + requireEventuallyFileContent(t, logPath, "before-close\n") + + logger.Close() + logger.Close() + + select { + case <-logger.done: + case <-time.After(time.Second): + t.Fatal("serial logger did not stop after Close") + } +} + +func TestSerialLogPathsFromSnapshot(t *testing.T) { + tmp := t.TempDir() + socketPath := filepath.Join("/var/lib/hypeman/guests/test", cloudHypervisorSerialSocketName) + data := []byte(`{"serial":{"mode":"Socket","socket":"` + socketPath + `"}}`) + require.NoError(t, os.WriteFile(filepath.Join(tmp, "config.json"), data, 0644)) + + gotSocket, gotLog := serialLogPathsFromSnapshot(tmp) + require.Equal(t, socketPath, gotSocket) + require.Equal(t, "/var/lib/hypeman/guests/test/logs/app.log", gotLog) +} + +func TestSerialLogPathsFromSnapshotIgnoresUnsupportedConfig(t *testing.T) { + tests := []struct { + name string + data string + }{ + {name: "file mode", data: `{"serial":{"mode":"File","file":"/tmp/app.log"}}`}, + {name: "missing socket", data: `{"serial":{"mode":"Socket"}}`}, + {name: "invalid json", data: `{`}, + {name: "missing serial", data: `{}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmp := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmp, "config.json"), []byte(tt.data), 0644)) + + gotSocket, gotLog := serialLogPathsFromSnapshot(tmp) + require.Empty(t, gotSocket) + require.Empty(t, gotLog) + }) + } +} + +func TestRemoveStaleSerialSocket(t *testing.T) { + tmp := socketTempDir(t) + socketPath := filepath.Join(tmp, "serial.sock") + + require.NoError(t, os.WriteFile(socketPath, []byte("stale"), 0644)) + require.FileExists(t, socketPath) + + require.NoError(t, removeStaleSerialSocket(socketPath)) + require.NoFileExists(t, socketPath) + require.NoError(t, removeStaleSerialSocket(socketPath)) +} + +func socketTempDir(t *testing.T) string { + t.Helper() + tmp, err := os.MkdirTemp("/tmp", "chlog-*") + require.NoError(t, err) + t.Cleanup(func() { _ = os.RemoveAll(tmp) }) + return tmp +} + +func requireEventuallyFileContent(t *testing.T, path, want string) { + t.Helper() + require.Eventually(t, func() bool { + data, err := os.ReadFile(path) + return err == nil && string(data) == want + }, time.Second, 10*time.Millisecond) +}