diff --git a/cmd/root.go b/cmd/root.go index c703bd25..3e18a2fc 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -78,7 +78,7 @@ func NewRootCmd(cfg *env.Env, tel *telemetry.Client, logger log.Logger) *cobra.C newUpdateCmd(cfg), newDocsCmd(), newAWSCmd(cfg), - newSnapshotCmd(cfg), + newSnapshotCmd(cfg, tel, logger), newResetCmd(cfg), ) diff --git a/cmd/snapshot.go b/cmd/snapshot.go index 3e54169f..5bff9aca 100644 --- a/cmd/snapshot.go +++ b/cmd/snapshot.go @@ -1,30 +1,135 @@ package cmd import ( + "context" "fmt" "os" "time" "github.com/localstack/lstk/internal/config" + "github.com/localstack/lstk/internal/container" "github.com/localstack/lstk/internal/emulator/aws" "github.com/localstack/lstk/internal/endpoint" "github.com/localstack/lstk/internal/env" + "github.com/localstack/lstk/internal/log" "github.com/localstack/lstk/internal/output" "github.com/localstack/lstk/internal/runtime" "github.com/localstack/lstk/internal/snapshot" + "github.com/localstack/lstk/internal/telemetry" "github.com/localstack/lstk/internal/ui" "github.com/spf13/cobra" ) -func newSnapshotCmd(cfg *env.Env) *cobra.Command { +func newSnapshotCmd(cfg *env.Env, tel *telemetry.Client, logger log.Logger) *cobra.Command { cmd := &cobra.Command{ Use: "snapshot", Short: "Manage emulator snapshots", } cmd.AddCommand(newSnapshotSaveCmd(cfg)) + cmd.AddCommand(newSnapshotLoadCmd(cfg, tel, logger)) return cmd } +func buildStarter(cfg *env.Env, rt runtime.Runtime, appConfig *config.Config, logger log.Logger, tel *telemetry.Client) snapshot.Starter { + return func(ctx context.Context, sink output.Sink) error { + opts := buildStartOptions(cfg, appConfig, logger, tel, false) + return container.Start(ctx, rt, sink, opts, false) + } +} + +func newSnapshotLoadCmd(cfg *env.Env, tel *telemetry.Client, logger log.Logger) *cobra.Command { + cmd := &cobra.Command{ + Use: "load REF", + Short: "Load a snapshot into the running emulator", + Long: `Load a snapshot into the running emulator, starting it first if needed. + +REF identifies the snapshot to load: + + lstk snapshot load my-baseline # loads ./my-baseline or ./my-baseline.zip + lstk snapshot load ./checkpoint.zip # loads from explicit path + lstk snapshot load pod:my-baseline # loads from LocalStack Cloud + +Merge strategies control how snapshot state is combined with running state: + + --merge=account-region-merge (default) snapshot wins on (service, account, region) overlap + --merge=overwrite wipe running state, then load + --merge=service-merge snapshot wins per-resource; non-overlapping resources combined`, + Args: cobra.ExactArgs(1), + PreRunE: initConfig(nil), + RunE: func(cmd *cobra.Command, args []string) error { + dryRun, err := cmd.Flags().GetBool("dry-run") + if err != nil { + return err + } + if dryRun { + return fmt.Errorf("--dry-run is not yet implemented") + } + + strategy, err := cmd.Flags().GetString("merge") + if err != nil { + return err + } + + src, err := snapshot.ParseSource(args[0]) + if err != nil { + return err + } + + if err := snapshot.ValidateMergeStrategy(strategy); err != nil { + return err + } + + rt, client, host, containers, appConfig, err := resolveSnapshotDeps(cmd.Context(), cfg) + if err != nil { + return err + } + + starter := buildStarter(cfg, rt, appConfig, logger, tel) + + if isInteractiveMode(cfg) { + return ui.RunSnapshotLoad(cmd.Context(), rt, containers, client, host, src, cfg.AuthToken, strategy, starter) + } + sink := output.NewPlainSink(os.Stdout) + switch src.Kind { + case snapshot.KindPod: + return snapshot.LoadPod(cmd.Context(), rt, containers, client, host, src.Value, cfg.AuthToken, strategy, starter, sink) + default: + return snapshot.LoadLocal(cmd.Context(), rt, containers, client, host, src.Value, strategy, starter, sink) + } + }, + } + cmd.Flags().String("merge", snapshot.MergeStrategyAccountRegion, "Merge strategy: overwrite, account-region-merge, service-merge") + cmd.Flags().Bool("dry-run", false, "Preview changes without applying (not yet implemented)") + return cmd +} + +func resolveSnapshotDeps(ctx context.Context, cfg *env.Env) (rt runtime.Runtime, client *aws.Client, host string, containers []config.ContainerConfig, appConfig *config.Config, err error) { + appConfig, err = config.Get() + if err != nil { + return nil, nil, "", nil, nil, fmt.Errorf("failed to get config: %w", err) + } + + var awsContainer config.ContainerConfig + var found bool + for _, c := range appConfig.Containers { + if c.Type == config.EmulatorAWS { + awsContainer = c + found = true + break + } + } + if !found { + return nil, nil, "", nil, nil, fmt.Errorf("snapshot is only supported for the AWS emulator") + } + + rt, err = runtime.NewDockerRuntime(cfg.DockerHost) + if err != nil { + return nil, nil, "", nil, nil, err + } + host, _ = endpoint.ResolveHost(ctx, awsContainer.Port, cfg.LocalStackHost) + return rt, aws.NewClient(), host, []config.ContainerConfig{awsContainer}, appConfig, nil +} + func newSnapshotSaveCmd(cfg *env.Env) *cobra.Command { return &cobra.Command{ Use: "save [destination]", @@ -53,31 +158,10 @@ To save to a remote pod on the LocalStack platform, use the pod: prefix: return err } - appConfig, err := config.Get() - if err != nil { - return fmt.Errorf("failed to get config: %w", err) - } - - var awsContainer config.ContainerConfig - var found bool - for _, c := range appConfig.Containers { - if c.Type == config.EmulatorAWS { - awsContainer = c - found = true - break - } - } - if !found { - return fmt.Errorf("snapshot is only supported for the AWS emulator") - } - - rt, err := runtime.NewDockerRuntime(cfg.DockerHost) + rt, client, host, containers, _, err := resolveSnapshotDeps(cmd.Context(), cfg) if err != nil { return err } - host, _ := endpoint.ResolveHost(cmd.Context(), awsContainer.Port, cfg.LocalStackHost) - client := aws.NewClient() - containers := []config.ContainerConfig{awsContainer} if isInteractiveMode(cfg) { return ui.RunSnapshotSave(cmd.Context(), rt, containers, client, host, dest, cfg.AuthToken) diff --git a/internal/emulator/aws/client.go b/internal/emulator/aws/client.go index 6b5da841..99b6e4f1 100644 --- a/internal/emulator/aws/client.go +++ b/internal/emulator/aws/client.go @@ -178,6 +178,121 @@ func (c *Client) ExportState(ctx context.Context, host string, dst io.Writer) er return nil } +func (c *Client) ImportState(ctx context.Context, host string, src io.Reader, strategy string) error { + url := fmt.Sprintf("http://%s/_localstack/pods", host) + if strategy != "" { + url += "?merge=" + strategy + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, src) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := c.http.Do(req) + if err != nil { + return fmt.Errorf("connect to LocalStack: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusUnprocessableEntity { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("snapshot is incompatible with the running LocalStack version: %s", strings.TrimSpace(string(body))) + } + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("LocalStack returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 1024*1024) + scanner.Buffer(buf, 1024*1024) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + var event struct { + Service string `json:"service"` + Status string `json:"status"` + Message string `json:"message"` + } + if err := json.Unmarshal([]byte(line), &event); err != nil { + continue + } + if event.Status == "error" && event.Message != "" { + return fmt.Errorf("load failed for service %s: %s", event.Service, event.Message) + } + } + return scanner.Err() +} + +func (c *Client) LoadPodSnapshot(ctx context.Context, host, podName, authToken, strategy string) ([]string, error) { + url := fmt.Sprintf("http://%s/_localstack/pods/%s", host, podName) + if strategy != "" { + url += "?merge=" + strategy + } + req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, bytes.NewReader([]byte("{}"))) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(":"+authToken))) + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("connect to LocalStack: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusUnprocessableEntity { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("snapshot is incompatible with the running LocalStack version: %s", strings.TrimSpace(string(body))) + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("pod load failed (HTTP %d): %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var services []string + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 1024*1024) + scanner.Buffer(buf, 1024*1024) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + var event struct { + Event string `json:"event"` + Service string `json:"service"` + Status string `json:"status"` + Message string `json:"message"` + } + if err := json.Unmarshal([]byte(line), &event); err != nil { + continue + } + switch event.Event { + case "service": + switch event.Status { + case "ok": + services = append(services, event.Service) + case "error": + return nil, fmt.Errorf("load failed for service %s: %s", event.Service, event.Message) + } + case "completion": + if event.Status != "ok" { + return nil, fmt.Errorf("pod load failed: %s", event.Message) + } + return services, nil + } + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + return services, nil +} + func (c *Client) SavePodSnapshot(ctx context.Context, host, podName, authToken string) (snapshot.PodSaveResult, error) { url := fmt.Sprintf("http://%s/_localstack/pods/%s", host, podName) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader([]byte("{}"))) diff --git a/internal/output/events.go b/internal/output/events.go index 098466b6..260ce825 100644 --- a/internal/output/events.go +++ b/internal/output/events.go @@ -84,6 +84,11 @@ type PodSnapshotSavedEvent struct { Size int64 } +type SnapshotLoadedEvent struct { + Source string // display source shown to the user (e.g. "./snap.zip" or "pod:my-baseline") + Services []string // services restored +} + type AuthCompleteEvent struct{} // Event is a sealed marker — only event types in this package implement it, @@ -98,7 +103,8 @@ func (AuthCompleteEvent) sealedEvent() {} func (InstanceInfoEvent) sealedEvent() {} func (TableEvent) sealedEvent() {} func (ResourceSummaryEvent) sealedEvent() {} -func (PodSnapshotSavedEvent) sealedEvent() {} +func (PodSnapshotSavedEvent) sealedEvent() {} +func (SnapshotLoadedEvent) sealedEvent() {} func (ContainerStatusEvent) sealedEvent() {} func (ProgressEvent) sealedEvent() {} func (UserInputRequestEvent) sealedEvent() {} diff --git a/internal/output/plain_format.go b/internal/output/plain_format.go index 7645f284..c565c7b6 100644 --- a/internal/output/plain_format.go +++ b/internal/output/plain_format.go @@ -42,6 +42,8 @@ func FormatEventLine(event Event) (string, bool) { return formatResourceSummary(e), true case PodSnapshotSavedEvent: return formatPodSnapshotSaved(e), true + case SnapshotLoadedEvent: + return formatSnapshotLoaded(e), true case AuthCompleteEvent: return "", false default: @@ -198,6 +200,15 @@ func formatResourceSummary(e ResourceSummaryEvent) string { return fmt.Sprintf("~ %d resources · %d services", e.Resources, e.Services) } +func formatSnapshotLoaded(e SnapshotLoadedEvent) string { + var sb strings.Builder + sb.WriteString(SuccessMarker() + fmt.Sprintf(" Snapshot loaded from %s", e.Source)) + if len(e.Services) > 0 { + sb.WriteString("\n• Services: " + strings.Join(e.Services, ", ")) + } + return sb.String() +} + func formatPodSnapshotSaved(e PodSnapshotSavedEvent) string { var sb strings.Builder sb.WriteString(SuccessMarker() + fmt.Sprintf(" Snapshot saved to pod:%s", e.PodName)) diff --git a/internal/output/plain_format_test.go b/internal/output/plain_format_test.go index eb629871..7e269e24 100644 --- a/internal/output/plain_format_test.go +++ b/internal/output/plain_format_test.go @@ -210,6 +210,32 @@ func TestFormatEventLine(t *testing.T) { want: SuccessMarker() + " Snapshot saved to pod:minimal-pod", wantOK: true, }, + + // snapshot load events + { + name: "snapshot loaded with services", + event: SnapshotLoadedEvent{Source: "./my-baseline.zip", Services: []string{"s3", "dynamodb"}}, + want: SuccessMarker() + " Snapshot loaded from ./my-baseline.zip\n• Services: s3, dynamodb", + wantOK: true, + }, + { + name: "snapshot loaded no services", + event: SnapshotLoadedEvent{Source: "./snap.zip"}, + want: SuccessMarker() + " Snapshot loaded from ./snap.zip", + wantOK: true, + }, + { + name: "pod snapshot loaded with services", + event: SnapshotLoadedEvent{Source: "pod:my-baseline", Services: []string{"s3", "lambda"}}, + want: SuccessMarker() + " Snapshot loaded from pod:my-baseline\n• Services: s3, lambda", + wantOK: true, + }, + { + name: "pod snapshot loaded no services", + event: SnapshotLoadedEvent{Source: "pod:empty-pod"}, + want: SuccessMarker() + " Snapshot loaded from pod:empty-pod", + wantOK: true, + }, } for _, tt := range tests { diff --git a/internal/snapshot/destination.go b/internal/snapshot/destination.go index 17b2e36f..9c8590dc 100644 --- a/internal/snapshot/destination.go +++ b/internal/snapshot/destination.go @@ -52,6 +52,67 @@ func displayPath(abs, cwd, home string) string { return abs } +// ParseSource resolves a user-supplied source REF for loading a snapshot. +// Unlike ParseDestination it never auto-generates a name: REF is required. +// For local paths, the file must exist; if no extension is given, .zip is tried as a fallback. +func ParseSource(ref string) (Destination, error) { + if ref == "" { + return Destination{}, fmt.Errorf("REF is required for snapshot load") + } + + lower := strings.ToLower(ref) + switch { + case strings.HasPrefix(lower, "pod://"): + podName := ref[len("pod://"):] + return Destination{}, fmt.Errorf("'%s' is not a valid reference. Aliases use a single colon. Did you mean:\npod:%s", ref, podName) + case strings.HasPrefix(lower, "pod:"): + podName := ref[len("pod:"):] + if !validPodName.MatchString(podName) { + return Destination{}, fmt.Errorf("invalid pod name %q: use letters, digits, and hyphens only, starting with a letter or digit", podName) + } + return Destination{Kind: KindPod, Value: podName}, nil + case strings.HasPrefix(lower, "s3://"), + strings.HasPrefix(lower, "oras://"): + return Destination{}, ErrRemoteNotSupported + case strings.Contains(lower, "://"): + scheme, _, _ := strings.Cut(ref, "://") + return Destination{}, fmt.Errorf("%w: %q", ErrUnknownScheme, scheme+"://") + } + + if ref == "~" || strings.HasPrefix(ref, "~/") || strings.HasPrefix(ref, `~\`) { + home, err := os.UserHomeDir() + if err != nil { + return Destination{}, fmt.Errorf("resolve home directory: %w", err) + } + ref = filepath.Join(home, strings.TrimLeft(ref[1:], `/\`)) + } + + abs, err := filepath.Abs(ref) + if err != nil { + return Destination{}, fmt.Errorf("resolve path: %w", err) + } + + // Try the path as-is first, then with .zip appended as a fallback for bare + // names (e.g. "my-snapshot" → "my-snapshot.zip" since that is what save produces). + resolved, err := resolveSourcePath(abs) + if err != nil { + return Destination{}, err + } + return Destination{Kind: KindLocal, Value: resolved}, nil +} + +// resolveSourcePath returns the first existing path among: abs as-is, then abs+".zip". +func resolveSourcePath(abs string) (string, error) { + if _, err := os.Stat(abs); err == nil { + return abs, nil + } + withZip := abs + ".zip" + if _, err := os.Stat(withZip); err == nil { + return withZip, nil + } + return "", fmt.Errorf("snapshot file not found: %q (also tried %q)", abs, withZip) +} + // ParseDestination resolves a user-supplied destination to a local path (KindLocal) or validated pod name (KindPod). func ParseDestination(dest string, now time.Time) (Destination, error) { if dest == "" { diff --git a/internal/snapshot/destination_test.go b/internal/snapshot/destination_test.go index ae11b1a9..884a9d83 100644 --- a/internal/snapshot/destination_test.go +++ b/internal/snapshot/destination_test.go @@ -16,6 +16,188 @@ import ( +func TestParseSource(t *testing.T) { + t.Parallel() + wd, err := os.Getwd() + require.NoError(t, err) + home, err := os.UserHomeDir() + require.NoError(t, err) + + dir := t.TempDir() + existingZip := filepath.Join(dir, "snap.zip") + require.NoError(t, os.WriteFile(existingZip, []byte("data"), 0600)) + existingBare := filepath.Join(dir, "bare") // no extension — snap.zip fallback exists + require.NoError(t, os.WriteFile(existingBare+".zip", []byte("data"), 0600)) + existingNoExt := filepath.Join(dir, "noext") // no extension, no .zip counterpart either + require.NoError(t, os.WriteFile(existingNoExt, []byte("data"), 0600)) + + type testCase struct { + name string + input string + wantKind snapshot.DestinationKind + wantPath string + wantPodName string + wantErr string + wantRemoteErr bool + wantSchemeErr bool + } + + tests := []testCase{ + // --- empty ref --- + { + name: "empty ref", + input: "", + wantErr: "REF is required", + }, + + // --- local paths (file must exist) --- + { + name: "explicit .zip path", + input: existingZip, + wantKind: snapshot.KindLocal, + wantPath: existingZip, + }, + { + name: "bare name resolves to .zip fallback", + input: existingBare, + wantKind: snapshot.KindLocal, + wantPath: existingBare + ".zip", + }, + { + name: "file without .zip extension resolves as-is", + input: existingNoExt, + wantKind: snapshot.KindLocal, + wantPath: existingNoExt, + }, + { + name: "nonexistent file returns error", + input: filepath.Join(dir, "missing.zip"), + wantErr: "snapshot file not found", + }, + { + name: "nonexistent bare name returns error", + input: filepath.Join(dir, "ghost"), + wantErr: "snapshot file not found", + }, + { + name: "relative path resolved via cwd", + input: ".", + wantKind: snapshot.KindLocal, + wantPath: wd, + }, + + // --- tilde expansion --- + { + name: "tilde expands to home", + input: "~/.", + wantKind: snapshot.KindLocal, + wantPath: home, + }, + + // --- pod sources --- + { + name: "pod: prefix", + input: "pod:my-baseline", + wantKind: snapshot.KindPod, + wantPodName: "my-baseline", + }, + { + name: "Pod: case insensitive", + input: "Pod:my-baseline", + wantKind: snapshot.KindPod, + wantPodName: "my-baseline", + }, + { + name: "pod:// rejected with did-you-mean hint", + input: "pod://my-baseline", + wantErr: "not a valid reference. Aliases use a single colon. Did you mean:\npod:my-baseline", + }, + { + name: "pod: empty name", + input: "pod:", + wantErr: "invalid pod name", + }, + { + name: "pod: leading hyphen", + input: "pod:-bad", + wantErr: "invalid pod name", + }, + + // --- remote schemes --- + { + name: "s3:// not supported", + input: "s3://bucket/key", + wantRemoteErr: true, + }, + { + name: "oras:// not supported", + input: "oras://registry/image", + wantRemoteErr: true, + }, + { + name: "unknown scheme", + input: "gcs://bucket/key", + wantSchemeErr: true, + }, + } + + if runtime.GOOS == "windows" { + tests = append(tests, + testCase{ + name: "windows tilde backslash", + input: `~\` + filepath.Base(existingZip), + wantKind: snapshot.KindLocal, + // The resolved path won't equal existingZip (different dir), so just + // check it doesn't error; path matching is covered by the cross-platform cases. + wantErr: "snapshot file not found", + }, + testCase{ + name: "windows abs backslash to existing zip", + input: existingZip, + wantKind: snapshot.KindLocal, + wantPath: existingZip, + }, + testCase{ + name: "windows abs forward-slash to existing zip", + input: strings.ReplaceAll(existingZip, `\`, `/`), + wantKind: snapshot.KindLocal, + wantPath: existingZip, + }, + ) + } + + for _, tc := range tests { + name := tc.input + if tc.name != "" { + name = tc.name + } + t.Run(name, func(t *testing.T) { + t.Parallel() + got, err := snapshot.ParseSource(tc.input) + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + return + } + if tc.wantRemoteErr { + require.ErrorIs(t, err, snapshot.ErrRemoteNotSupported) + return + } + if tc.wantSchemeErr { + require.ErrorIs(t, err, snapshot.ErrUnknownScheme) + return + } + require.NoError(t, err) + assert.Equal(t, tc.wantKind, got.Kind) + if tc.wantPodName != "" { + assert.Equal(t, tc.wantPodName, got.Value) + } else { + assert.Equal(t, tc.wantPath, got.Value) + } + }) + } +} + func TestDisplayPath(t *testing.T) { t.Parallel() diff --git a/internal/snapshot/load.go b/internal/snapshot/load.go new file mode 100644 index 00000000..5ab7e4f6 --- /dev/null +++ b/internal/snapshot/load.go @@ -0,0 +1,142 @@ +package snapshot + +//go:generate mockgen -source=load.go -destination=mock_load_client_test.go -package=snapshot_test + +import ( + "context" + "fmt" + "io" + "os" + + "github.com/localstack/lstk/internal/config" + "github.com/localstack/lstk/internal/container" + "github.com/localstack/lstk/internal/output" + "github.com/localstack/lstk/internal/runtime" +) + +const ( + MergeStrategyAccountRegion = "account-region-merge" + MergeStrategyOverwrite = "overwrite" + MergeStrategyService = "service-merge" +) + +func ValidateMergeStrategy(strategy string) error { + switch strategy { + case MergeStrategyAccountRegion, MergeStrategyOverwrite, MergeStrategyService: + return nil + default: + return fmt.Errorf("unknown merge strategy %q: use overwrite, account-region-merge, or service-merge", strategy) + } +} + +// Starter is called to auto-start the emulator when none is running. +type Starter func(ctx context.Context, sink output.Sink) error + +// LocalLoadClient is satisfied by aws.Client. +type LocalLoadClient interface { + // ImportState posts a zip to /_localstack/pods[?merge=strategy] and streams + // the NDJSON response. strategy is passed as-is; empty means server default. + ImportState(ctx context.Context, host string, src io.Reader, strategy string) error + // ResetState wipes all running state via POST /_localstack/state/reset. + // Used to implement overwrite client-side before importing. + ResetState(ctx context.Context, host string) error +} + +// PodLoader is satisfied by aws.Client. +type PodLoader interface { + // LoadPodSnapshot issues PUT /_localstack/pods/{name}?merge=strategy and + // streams the NDJSON response. + LoadPodSnapshot(ctx context.Context, host, podName, authToken, strategy string) ([]string, error) +} + +// load is the shared entry point for both LoadLocal and LoadPod. +// It checks runtime health, auto-starts the emulator if needed, then runs do(). +func load(ctx context.Context, rt runtime.Runtime, containers []config.ContainerConfig, sink output.Sink, starter Starter, spinnerText string, onSuccess func(), do func() error) (retErr error) { + if err := rt.IsHealthy(ctx); err != nil { + rt.EmitUnhealthyError(sink, err) + return output.NewSilentError(fmt.Errorf("runtime not healthy: %w", err)) + } + + runningContainers, err := container.RunningEmulators(ctx, rt, containers) + if err != nil { + return fmt.Errorf("checking emulator status: %w", err) + } + + if len(runningContainers) == 0 { + if starter == nil { + sink.Emit(output.ErrorEvent{ + Title: "LocalStack is not running", + Actions: []output.ErrorAction{ + {Label: "Start LocalStack:", Value: "lstk"}, + {Label: "See help:", Value: "lstk -h"}, + }, + }) + return output.NewSilentError(fmt.Errorf("LocalStack is not running")) + } + if err := starter(ctx, sink); err != nil { + return err + } + } + + sink.Emit(output.SpinnerStart(spinnerText)) + defer func() { + sink.Emit(output.SpinnerStop()) + if retErr == nil { + onSuccess() + } + }() + + return do() +} + +func LoadLocal(ctx context.Context, rt runtime.Runtime, containers []config.ContainerConfig, client LocalLoadClient, host, src, strategy string, starter Starter, sink output.Sink) error { + cwd, _ := os.Getwd() + home, _ := os.UserHomeDir() + + return load(ctx, rt, containers, sink, starter, + "Loading snapshot...", + func() { + sink.Emit(output.SnapshotLoadedEvent{Source: displayPath(src, cwd, home)}) + }, + func() error { + // overwrite is handled client-side: reset running state, then import + // with the server default (account-region-merge on clean state = overwrite). + if strategy == MergeStrategyOverwrite { + if err := client.ResetState(ctx, host); err != nil { + return fmt.Errorf("reset state: %w", err) + } + strategy = "" + } + + f, err := os.Open(src) + if err != nil { + return fmt.Errorf("open snapshot: %w", err) + } + defer func() { _ = f.Close() }() + + return client.ImportState(ctx, host, f, strategy) + }, + ) +} + +func LoadPod(ctx context.Context, rt runtime.Runtime, containers []config.ContainerConfig, loader PodLoader, host, podName, authToken, strategy string, starter Starter, sink output.Sink) error { + if authToken == "" { + return fmt.Errorf("pod snapshots require authentication — set LOCALSTACK_AUTH_TOKEN or run %q", "lstk login") + } + + var services []string + return load(ctx, rt, containers, sink, starter, + fmt.Sprintf("Loading snapshot from pod %q...", podName), + func() { + sink.Emit(output.SnapshotLoadedEvent{ + Source: "pod:" + podName, + Services: services, + }) + }, + func() error { + var err error + services, err = loader.LoadPodSnapshot(ctx, host, podName, authToken, strategy) + return err + }, + ) +} diff --git a/internal/snapshot/load_test.go b/internal/snapshot/load_test.go new file mode 100644 index 00000000..24279ead --- /dev/null +++ b/internal/snapshot/load_test.go @@ -0,0 +1,277 @@ +package snapshot_test + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "testing" + + "github.com/localstack/lstk/internal/output" + "github.com/localstack/lstk/internal/runtime" + "github.com/localstack/lstk/internal/snapshot" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func nopStarter(ctx context.Context, sink output.Sink) error { return nil } + +func mockLocalClientReturning(t *testing.T, importErr error) *MockLocalLoadClient { + t.Helper() + ctrl := gomock.NewController(t) + m := NewMockLocalLoadClient(ctrl) + if importErr == nil { + m.EXPECT().ImportState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + } else { + m.EXPECT().ImportState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(importErr).AnyTimes() + } + return m +} + +func writeSnapshotFile(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "snap.zip") + require.NoError(t, os.WriteFile(path, []byte(content), 0600)) + return path +} + +func TestLoadLocal_Success(t *testing.T) { + t.Parallel() + src := writeSnapshotFile(t, "ZIP_DATA") + client := mockLocalClientReturning(t, nil) + sink, getEvents := captureEvents(t) + + err := snapshot.LoadLocal(context.Background(), healthyRunningMock(t), awsContainers, client, "", src, "", nopStarter, sink) + require.NoError(t, err) + + events := getEvents() + var spinnerStarted, spinnerStopped, loaded bool + for _, e := range events { + switch ev := e.(type) { + case output.SpinnerEvent: + if ev.Active { + spinnerStarted = true + } else { + spinnerStopped = true + } + case output.SnapshotLoadedEvent: + loaded = true + assert.NotEmpty(t, ev.Source) + } + } + assert.True(t, spinnerStarted, "spinner should have started") + assert.True(t, spinnerStopped, "spinner should have stopped") + assert.True(t, loaded, "SnapshotLoadedEvent should have been emitted") +} + +func TestLoadLocal_OverwriteStrategy(t *testing.T) { + t.Parallel() + src := writeSnapshotFile(t, "ZIP_DATA") + + ctrl := gomock.NewController(t) + client := NewMockLocalLoadClient(ctrl) + client.EXPECT().ResetState(gomock.Any(), gomock.Any()).Return(nil) + client.EXPECT().ImportState(gomock.Any(), gomock.Any(), gomock.Any(), "").Return(nil) + + sink := output.NewPlainSink(io.Discard) + err := snapshot.LoadLocal(context.Background(), healthyRunningMock(t), awsContainers, client, "", src, snapshot.MergeStrategyOverwrite, nopStarter, sink) + require.NoError(t, err) +} + +func TestLoadLocal_ResetErrorAbortsImport(t *testing.T) { + t.Parallel() + src := writeSnapshotFile(t, "ZIP_DATA") + + ctrl := gomock.NewController(t) + client := NewMockLocalLoadClient(ctrl) + client.EXPECT().ResetState(gomock.Any(), gomock.Any()).Return(fmt.Errorf("reset failed")) + + sink := output.NewPlainSink(io.Discard) + err := snapshot.LoadLocal(context.Background(), healthyRunningMock(t), awsContainers, client, "", src, snapshot.MergeStrategyOverwrite, nopStarter, sink) + require.Error(t, err) + assert.Contains(t, err.Error(), "reset failed") +} + +func TestLoadLocal_ImportError(t *testing.T) { + t.Parallel() + src := writeSnapshotFile(t, "ZIP_DATA") + client := mockLocalClientReturning(t, fmt.Errorf("incompatible version")) + sink := output.NewPlainSink(io.Discard) + + err := snapshot.LoadLocal(context.Background(), healthyRunningMock(t), awsContainers, client, "", src, "", nopStarter, sink) + require.Error(t, err) + assert.Contains(t, err.Error(), "incompatible version") +} + +func TestLoadLocal_FileNotFound(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + client := NewMockLocalLoadClient(ctrl) + sink := output.NewPlainSink(io.Discard) + + err := snapshot.LoadLocal(context.Background(), healthyRunningMock(t), awsContainers, client, "", "/no/such/file.zip", "", nopStarter, sink) + require.Error(t, err) +} + +func TestLoadLocal_EmulatorNotRunning_AutoStarts(t *testing.T) { + t.Parallel() + src := writeSnapshotFile(t, "ZIP_DATA") + + ctrl := gomock.NewController(t) + mockRT := runtime.NewMockRuntime(ctrl) + mockRT.EXPECT().IsHealthy(gomock.Any()).Return(nil) + mockRT.EXPECT().IsRunning(gomock.Any(), "localstack-aws").Return(false, nil) + mockRT.EXPECT().FindRunningByImage(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + + client := NewMockLocalLoadClient(ctrl) + client.EXPECT().ImportState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + + var starterCalled bool + starter := func(ctx context.Context, sink output.Sink) error { + starterCalled = true + return nil + } + + sink := output.NewPlainSink(io.Discard) + err := snapshot.LoadLocal(context.Background(), mockRT, awsContainers, client, "", src, "", starter, sink) + require.NoError(t, err) + assert.True(t, starterCalled, "starter should have been called when emulator is not running") +} + +func TestLoadLocal_EmulatorNotRunning_NoStarter(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockRT := runtime.NewMockRuntime(ctrl) + mockRT.EXPECT().IsHealthy(gomock.Any()).Return(nil) + mockRT.EXPECT().IsRunning(gomock.Any(), "localstack-aws").Return(false, nil) + mockRT.EXPECT().FindRunningByImage(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + + client := NewMockLocalLoadClient(ctrl) + src := writeSnapshotFile(t, "ZIP_DATA") + sink, getEvents := captureEvents(t) + + err := snapshot.LoadLocal(context.Background(), mockRT, awsContainers, client, "", src, "", nil, sink) + require.Error(t, err) + assert.True(t, output.IsSilent(err)) + + var gotErrorEvent bool + for _, e := range getEvents() { + if ev, ok := e.(output.ErrorEvent); ok { + gotErrorEvent = true + assert.Contains(t, ev.Title, "not running") + } + } + assert.True(t, gotErrorEvent, "ErrorEvent should have been emitted") +} + +func TestLoadLocal_UnhealthyRuntime(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockRT := runtime.NewMockRuntime(ctrl) + mockRT.EXPECT().IsHealthy(gomock.Any()).Return(fmt.Errorf("docker unavailable")) + mockRT.EXPECT().EmitUnhealthyError(gomock.Any(), gomock.Any()) + + client := NewMockLocalLoadClient(ctrl) + src := writeSnapshotFile(t, "ZIP_DATA") + sink := output.NewPlainSink(io.Discard) + + err := snapshot.LoadLocal(context.Background(), mockRT, awsContainers, client, "", src, "", nopStarter, sink) + require.Error(t, err) + assert.True(t, output.IsSilent(err)) +} + +func TestLoadPod_Success(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + loader := NewMockPodLoader(ctrl) + loader.EXPECT().LoadPodSnapshot(gomock.Any(), gomock.Any(), "my-baseline", "test-token", ""). + Return([]string{"s3", "dynamodb"}, nil) + + sink, getEvents := captureEvents(t) + err := snapshot.LoadPod(context.Background(), healthyRunningMock(t), awsContainers, loader, "", "my-baseline", "test-token", "", nopStarter, sink) + require.NoError(t, err) + + events := getEvents() + var spinnerStarted, spinnerStopped bool + var loaded *output.SnapshotLoadedEvent + for _, e := range events { + switch ev := e.(type) { + case output.SpinnerEvent: + if ev.Active { + spinnerStarted = true + } else { + spinnerStopped = true + } + case output.SnapshotLoadedEvent: + loaded = &ev + } + } + assert.True(t, spinnerStarted, "spinner should have started") + assert.True(t, spinnerStopped, "spinner should have stopped") + require.NotNil(t, loaded, "SnapshotLoadedEvent should have been emitted") + assert.Equal(t, "pod:my-baseline", loaded.Source) + assert.Equal(t, []string{"s3", "dynamodb"}, loaded.Services) +} + +func TestLoadPod_NoAuthToken(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + loader := NewMockPodLoader(ctrl) + sink := output.NewPlainSink(io.Discard) + + err := snapshot.LoadPod(context.Background(), runtime.NewMockRuntime(ctrl), awsContainers, loader, "", "my-baseline", "", "", nopStarter, sink) + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication") +} + +func TestLoadPod_LoaderError(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + loader := NewMockPodLoader(ctrl) + loader.EXPECT().LoadPodSnapshot(gomock.Any(), gomock.Any(), "my-baseline", "test-token", gomock.Any()). + Return(nil, fmt.Errorf("platform unreachable")) + + sink, _ := captureEvents(t) + err := snapshot.LoadPod(context.Background(), healthyRunningMock(t), awsContainers, loader, "", "my-baseline", "test-token", "", nopStarter, sink) + require.Error(t, err) + assert.Contains(t, err.Error(), "platform unreachable") +} + +func TestLoadPod_WithMergeStrategy(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + loader := NewMockPodLoader(ctrl) + loader.EXPECT().LoadPodSnapshot(gomock.Any(), gomock.Any(), "my-pod", "tok", snapshot.MergeStrategyService). + Return([]string{"s3"}, nil) + + sink := output.NewPlainSink(io.Discard) + err := snapshot.LoadPod(context.Background(), healthyRunningMock(t), awsContainers, loader, "", "my-pod", "tok", snapshot.MergeStrategyService, nopStarter, sink) + require.NoError(t, err) +} + +func TestLoadPod_EmulatorNotRunning_AutoStarts(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockRT := runtime.NewMockRuntime(ctrl) + mockRT.EXPECT().IsHealthy(gomock.Any()).Return(nil) + mockRT.EXPECT().IsRunning(gomock.Any(), "localstack-aws").Return(false, nil) + mockRT.EXPECT().FindRunningByImage(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + + loader := NewMockPodLoader(ctrl) + loader.EXPECT().LoadPodSnapshot(gomock.Any(), gomock.Any(), "my-pod", "tok", gomock.Any()). + Return([]string{"s3"}, nil) + + var starterCalled bool + starter := func(ctx context.Context, sink output.Sink) error { + starterCalled = true + return nil + } + + sink := output.NewPlainSink(io.Discard) + err := snapshot.LoadPod(context.Background(), mockRT, awsContainers, loader, "", "my-pod", "tok", "", starter, sink) + require.NoError(t, err) + assert.True(t, starterCalled, "starter should have been called when emulator is not running") +} diff --git a/internal/snapshot/mock_load_client_test.go b/internal/snapshot/mock_load_client_test.go new file mode 100644 index 00000000..cc4ea432 --- /dev/null +++ b/internal/snapshot/mock_load_client_test.go @@ -0,0 +1,109 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: load.go +// +// Generated by this command: +// +// mockgen -source=load.go -destination=mock_load_client_test.go -package=snapshot_test +// + +// Package snapshot_test is a generated GoMock package. +package snapshot_test + +import ( + context "context" + io "io" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockLocalLoadClient is a mock of LocalLoadClient interface. +type MockLocalLoadClient struct { + ctrl *gomock.Controller + recorder *MockLocalLoadClientMockRecorder + isgomock struct{} +} + +// MockLocalLoadClientMockRecorder is the mock recorder for MockLocalLoadClient. +type MockLocalLoadClientMockRecorder struct { + mock *MockLocalLoadClient +} + +// NewMockLocalLoadClient creates a new mock instance. +func NewMockLocalLoadClient(ctrl *gomock.Controller) *MockLocalLoadClient { + mock := &MockLocalLoadClient{ctrl: ctrl} + mock.recorder = &MockLocalLoadClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLocalLoadClient) EXPECT() *MockLocalLoadClientMockRecorder { + return m.recorder +} + +// ImportState mocks base method. +func (m *MockLocalLoadClient) ImportState(ctx context.Context, host string, src io.Reader, strategy string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ImportState", ctx, host, src, strategy) + ret0, _ := ret[0].(error) + return ret0 +} + +// ImportState indicates an expected call of ImportState. +func (mr *MockLocalLoadClientMockRecorder) ImportState(ctx, host, src, strategy any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ImportState", reflect.TypeOf((*MockLocalLoadClient)(nil).ImportState), ctx, host, src, strategy) +} + +// ResetState mocks base method. +func (m *MockLocalLoadClient) ResetState(ctx context.Context, host string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResetState", ctx, host) + ret0, _ := ret[0].(error) + return ret0 +} + +// ResetState indicates an expected call of ResetState. +func (mr *MockLocalLoadClientMockRecorder) ResetState(ctx, host any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetState", reflect.TypeOf((*MockLocalLoadClient)(nil).ResetState), ctx, host) +} + +// MockPodLoader is a mock of PodLoader interface. +type MockPodLoader struct { + ctrl *gomock.Controller + recorder *MockPodLoaderMockRecorder + isgomock struct{} +} + +// MockPodLoaderMockRecorder is the mock recorder for MockPodLoader. +type MockPodLoaderMockRecorder struct { + mock *MockPodLoader +} + +// NewMockPodLoader creates a new mock instance. +func NewMockPodLoader(ctrl *gomock.Controller) *MockPodLoader { + mock := &MockPodLoader{ctrl: ctrl} + mock.recorder = &MockPodLoaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPodLoader) EXPECT() *MockPodLoaderMockRecorder { + return m.recorder +} + +// LoadPodSnapshot mocks base method. +func (m *MockPodLoader) LoadPodSnapshot(ctx context.Context, host, podName, authToken, strategy string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadPodSnapshot", ctx, host, podName, authToken, strategy) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadPodSnapshot indicates an expected call of LoadPodSnapshot. +func (mr *MockPodLoaderMockRecorder) LoadPodSnapshot(ctx, host, podName, authToken, strategy any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPodSnapshot", reflect.TypeOf((*MockPodLoader)(nil).LoadPodSnapshot), ctx, host, podName, authToken, strategy) +} diff --git a/internal/ui/run_snapshot_load.go b/internal/ui/run_snapshot_load.go new file mode 100644 index 00000000..f7e857c1 --- /dev/null +++ b/internal/ui/run_snapshot_load.go @@ -0,0 +1,27 @@ +package ui + +import ( + "context" + + "github.com/localstack/lstk/internal/config" + "github.com/localstack/lstk/internal/output" + "github.com/localstack/lstk/internal/runtime" + "github.com/localstack/lstk/internal/snapshot" +) + +// SnapshotLoadClient is satisfied by aws.Client. +type SnapshotLoadClient interface { + snapshot.LocalLoadClient + snapshot.PodLoader +} + +func RunSnapshotLoad(parentCtx context.Context, rt runtime.Runtime, containers []config.ContainerConfig, client SnapshotLoadClient, host string, src snapshot.Destination, authToken, strategy string, starter snapshot.Starter) error { + return runWithTUI(parentCtx, withoutHeader(), func(ctx context.Context, sink output.Sink) error { + switch src.Kind { + case snapshot.KindPod: + return snapshot.LoadPod(ctx, rt, containers, client, host, src.Value, authToken, strategy, starter, sink) + default: + return snapshot.LoadLocal(ctx, rt, containers, client, host, src.Value, strategy, starter, sink) + } + }) +}