diff --git a/helm/bundles/cortex-nova/values.yaml b/helm/bundles/cortex-nova/values.yaml index 452f11158..a6d777417 100644 --- a/helm/bundles/cortex-nova/values.yaml +++ b/helm/bundles/cortex-nova/values.yaml @@ -186,9 +186,12 @@ cortex-scheduling-controllers: # Used when maxVMsToProcess limits processing, allows faster catch-up and for the first reconcile shortReconcileInterval: 1m # Number of max VMs to process in one periodic reconciliation loop - maxVMsToProcess: 25 + maxVMsToProcess: 50 + # How often to rotate VM selection offset when maxVMsToProcess limits processing + # Every N reconcile cycles, the offset rotates to process different VMs + vmSelectionRotationInterval: 3 # Minimum successful reservations to use short interval - minSuccessForShortInterval: 1 + minSuccessForShortInterval: 0 # Maximum failures allowed to still use short interval maxFailuresForShortInterval: 99 # If true, uses hypervisor CRD as source of truth for VM location instead of postgres diff --git a/internal/scheduling/reservations/failover/controller.go b/internal/scheduling/reservations/failover/controller.go index e31a4aa1c..4c99f4b2a 100644 --- a/internal/scheduling/reservations/failover/controller.go +++ b/internal/scheduling/reservations/failover/controller.go @@ -228,6 +228,7 @@ func (c *FailoverReservationController) validateReservation(ctx context.Context, // reconcileSummary holds statistics from the reconciliation cycle. type reconcileSummary struct { + vmsMissingFailover int vmsProcessed int reservationsNeeded int totalReused int @@ -268,6 +269,7 @@ func (c *FailoverReservationController) ReconcilePeriodic(ctx context.Context) ( } logger.V(1).Info("found VMs from source", "count", len(vms)) + // todo: vms are vms from all AZs, we should consdier processing them by AZ (sequencial or in parallel) but not mixing them together // List only failover reservations using label selector var reservationList v1alpha1.ReservationList if err := c.List(ctx, &reservationList, client.MatchingLabels{ @@ -313,6 +315,7 @@ func (c *FailoverReservationController) ReconcilePeriodic(ctx context.Context) ( // 6. Create and assign reservations for VMs that need them assignSummary, hitMaxVMsLimit := c.reconcileCreateAndAssignReservations(ctx, vms, failoverReservations, allHypervisors) + summary.vmsMissingFailover = assignSummary.vmsMissingFailover summary.vmsProcessed = assignSummary.vmsProcessed summary.reservationsNeeded = assignSummary.reservationsNeeded summary.totalReused = assignSummary.totalReused @@ -332,6 +335,9 @@ func (c *FailoverReservationController) ReconcilePeriodic(ctx context.Context) ( "reconcileCount", c.reconcileCount, "duration", duration.Round(time.Millisecond), "requeueAfter", requeueAfter, + "totalVMs", len(vms), + "totalReservations", len(failoverReservations), + "vmsMissingFailover", summary.vmsMissingFailover, "vmsProcessed", summary.vmsProcessed, "reservationsNeeded", summary.reservationsNeeded, "reused", summary.totalReused, @@ -557,11 +563,12 @@ func (c *FailoverReservationController) reconcileCreateAndAssignReservations( vmsMissingFailover := c.calculateVMsMissingFailover(ctx, vms, failoverReservations) logger.V(1).Info("VMs missing failover reservations", "count", len(vmsMissingFailover)) + totalVMsMissingFailover := len(vmsMissingFailover) vmsMissingFailover, hitMaxVMsLimit := c.selectVMsToProcess(ctx, vmsMissingFailover, c.Config.MaxVMsToProcess) logger.V(1).Info("found hypervisors and vm missing failover reservation", "countHypervisors", len(allHypervisors), - "countVMsMissingFailover", len(vmsMissingFailover)) + "countVMsMissingFailover", totalVMsMissingFailover) totalReservationsNeeded := 0 for _, need := range vmsMissingFailover { @@ -649,6 +656,7 @@ func (c *FailoverReservationController) reconcileCreateAndAssignReservations( } return reconcileSummary{ + vmsMissingFailover: totalVMsMissingFailover, vmsProcessed: len(vmsMissingFailover), reservationsNeeded: totalReservationsNeeded, totalReused: totalReused, diff --git a/tools/visualize-reservations/main.go b/tools/visualize-reservations/main.go index 6a21d551c..9b5880be5 100644 --- a/tools/visualize-reservations/main.go +++ b/tools/visualize-reservations/main.go @@ -9,19 +9,26 @@ // // Flags: // +// --config=path Path to JSON config file (CLI flags override config values) // --sort=vm|vm-host|res-host Sort VMs by UUID, VM host, or reservation host // --postgres-secret=name Name of the kubernetes secret containing postgres credentials (default: cortex-nova-postgres) // --namespace=ns Namespace of the postgres secret (default: default) // --postgres-host=host Override postgres host (useful with port-forward, e.g., localhost) // --postgres-port=port Override postgres port (useful with port-forward, e.g., 5432) +// --postgres-port-forward Automatically run kubectl port-forward for postgres +// --postgres-port-forward-service=name Service name for port-forward (defaults to secret host) +// --postgres-port-forward-local-port=port Local port for port-forward (defaults to postgres port) +// --postgres-port-forward-remote-port=port Remote port for port-forward (defaults to postgres port) // --views=view1,view2,... Comma-separated list of views to show (default: all) // Available views: hypervisors, vms, reservations, summary, // hypervisor-summary, validation, stale, without-res, not-in-db, by-host // --hide=view1,view2,... Comma-separated list of views to hide (applied after --views) // --filter-name=pattern Filter hypervisors by name (substring match) // --filter-trait=trait Filter hypervisors by trait (e.g., CUSTOM_HANA_EXCLUSIVE_HOST) -// --hypervisor-context=name Kubernetes context for reading Hypervisors (default: current context) -// --reservation-context=name Kubernetes context for reading Reservations (default: current context) +// --hypervisor-context=name Kubernetes context for reading Hypervisors (default: current context, backward compatible) +// --hypervisor-contexts=ctx1,ctx2,... Kubernetes contexts for reading Hypervisors +// --reservation-context=name Kubernetes context for reading Reservations (default: current context, backward compatible) +// --reservation-contexts=ctx1,ctx2,... Kubernetes contexts for reading Reservations // --postgres-context=name Kubernetes context for reading postgres secret (default: current context) // // To connect to postgres when running locally, use kubectl port-forward: @@ -34,11 +41,15 @@ import ( "context" "database/sql" "encoding/json" + "errors" "flag" "fmt" + "net" "os" + "os/exec" "sort" "strings" + "time" "github.com/cobaltcore-dev/cortex/api/v1alpha1" hv1 "github.com/cobaltcore-dev/openstack-hypervisor-operator/api/v1" @@ -116,6 +127,27 @@ type hypervisorSummary struct { // viewSet tracks which views should be displayed type viewSet map[string]bool +type toolConfig struct { + Sort string `json:"sort"` + PostgresSecret string `json:"postgres_secret"` + Namespace string `json:"namespace"` + PostgresHost string `json:"postgres_host"` + PostgresPort string `json:"postgres_port"` + Views string `json:"views"` + Hide string `json:"hide"` + FilterName string `json:"filter_name"` + FilterTrait string `json:"filter_trait"` + HypervisorContext string `json:"hypervisor_context"` + HypervisorContexts []string `json:"hypervisor_contexts"` + ReservationContext string `json:"reservation_context"` + ReservationContexts []string `json:"reservation_contexts"` + PostgresContext string `json:"postgres_context"` + PostgresPortForward bool `json:"postgres_port_forward"` + PostgresPortForwardService string `json:"postgres_port_forward_service"` + PostgresPortForwardLocal string `json:"postgres_port_forward_local_port"` + PostgresPortForwardRemote string `json:"postgres_port_forward_remote_port"` +} + // Available view names const ( viewHypervisors = "hypervisors" @@ -182,6 +214,117 @@ func applyHideViews(views viewSet, hideFlag string) { } } +func loadConfig(configPath string) (toolConfig, error) { + if configPath == "" { + return toolConfig{}, nil + } + + content, err := os.ReadFile(configPath) + if err != nil { + return toolConfig{}, fmt.Errorf("reading config file %q: %w", configPath, err) + } + + var cfg toolConfig + if err := json.Unmarshal(content, &cfg); err != nil { + return toolConfig{}, fmt.Errorf("parsing config file %q as JSON: %w", configPath, err) + } + + return cfg, nil +} + +func splitCSV(value string) []string { + if value == "" { + return nil + } + parts := strings.Split(value, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + return out +} + +func uniqueStrings(in []string) []string { + if len(in) == 0 { + return nil + } + seen := make(map[string]bool, len(in)) + out := make([]string, 0, len(in)) + for _, v := range in { + if seen[v] { + continue + } + seen[v] = true + out = append(out, v) + } + return out +} + +func contextDisplayName(ctx string) string { + if ctx == "" { + return "current" + } + return ctx +} + +func annotateReservationName(name, ctx string, includeContext bool) string { + if !includeContext { + return name + } + return fmt.Sprintf("%s@%s", name, contextDisplayName(ctx)) +} + +func startPortForward(namespace, contextName, service, localPort, remotePort string) (func(), error) { + if service == "" { + return nil, errors.New("port-forward service is empty") + } + if localPort == "" || remotePort == "" { + return nil, errors.New("port-forward local/remote port must be set") + } + + resource := service + if !strings.Contains(resource, "/") { + resource = "svc/" + resource + } + + args := []string{"port-forward", resource, fmt.Sprintf("%s:%s", localPort, remotePort), "-n", namespace} + if contextName != "" { + args = append(args, "--context", contextName) + } + + cmd := exec.Command("kubectl", args...) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("starting kubectl port-forward: %w", err) + } + + cleanup := func() { + _ = cmd.Process.Kill() //nolint:errcheck // best-effort cleanup + _, _ = cmd.Process.Wait() //nolint:errcheck // best-effort cleanup + } + + readyDeadline := time.Now().Add(5 * time.Second) + for { + conn, err := net.DialTimeout("tcp", net.JoinHostPort("127.0.0.1", localPort), 200*time.Millisecond) + if err == nil { + _ = conn.Close() + break + } + if time.Now().After(readyDeadline) { + cleanup() + return nil, fmt.Errorf("timed out waiting for port-forward on localhost:%s", localPort) + } + time.Sleep(150 * time.Millisecond) + } + + return cleanup, nil +} + // getClientForContext creates a kubernetes client for the specified context. // If contextName is empty, it uses the current/default context. func getClientForContext(contextName string) (client.Client, error) { @@ -215,60 +358,134 @@ func getClientForContext(contextName string) (client.Client, error) { } func main() { + const ( + defaultSort = "vm" + defaultPostgresSecret = "cortex-nova-postgres" + defaultViews = "all" + ) + // Parse command line flags - sortBy := flag.String("sort", "vm", "Sort VMs by: vm (UUID), vm-host (VM's host), res-host (reservation host)") - postgresSecret := flag.String("postgres-secret", "cortex-nova-postgres", "Name of the kubernetes secret containing postgres credentials") + configPath := flag.String("config", "", "Path to JSON config file") + sortBy := flag.String("sort", defaultSort, "Sort VMs by: vm (UUID), vm-host (VM's host), res-host (reservation host)") + postgresSecret := flag.String("postgres-secret", defaultPostgresSecret, "Name of the kubernetes secret containing postgres credentials") namespace := flag.String("namespace", "", "Namespace of the postgres secret (defaults to 'default')") postgresHostOverride := flag.String("postgres-host", "", "Override postgres host (useful with port-forward, e.g., localhost)") postgresPortOverride := flag.String("postgres-port", "", "Override postgres port (useful with port-forward, e.g., 5432)") - viewsFlag := flag.String("views", "all", "Comma-separated list of views to show (all, hypervisors, vms, reservations, summary, hypervisor-summary, validation, stale, without-res, not-in-db, by-host)") + viewsFlag := flag.String("views", defaultViews, "Comma-separated list of views to show (all, hypervisors, vms, reservations, summary, hypervisor-summary, validation, stale, without-res, not-in-db, by-host)") hideFlag := flag.String("hide", "", "Comma-separated list of views to hide (applied after --views)") filterName := flag.String("filter-name", "", "Filter hypervisors by name (substring match)") filterTrait := flag.String("filter-trait", "", "Filter hypervisors by trait (e.g., CUSTOM_HANA_EXCLUSIVE_HOST)") - hypervisorContext := flag.String("hypervisor-context", "", "Kubernetes context for reading Hypervisors (default: current context)") - reservationContext := flag.String("reservation-context", "", "Kubernetes context for reading Reservations (default: current context)") + hypervisorContext := flag.String("hypervisor-context", "", "Kubernetes context for reading Hypervisors (default: current context); kept for backward compatibility") + hypervisorContexts := flag.String("hypervisor-contexts", "", "Comma-separated kubernetes contexts for reading Hypervisors (default: current context)") + reservationContext := flag.String("reservation-context", "", "Kubernetes context for reading Reservations (default: current context); kept for backward compatibility") + reservationContexts := flag.String("reservation-contexts", "", "Comma-separated kubernetes contexts for reading Reservations (default: current context)") postgresContext := flag.String("postgres-context", "", "Kubernetes context for reading postgres secret (default: current context)") + postgresPortForward := flag.Bool("postgres-port-forward", false, "Automatically run kubectl port-forward for postgres before connecting") + postgresPortForwardService := flag.String("postgres-port-forward-service", "", "Service name used for postgres port-forward (defaults to secret host)") + postgresPortForwardLocalPort := flag.String("postgres-port-forward-local-port", "", "Local port to use for postgres port-forward (defaults to postgres port)") + postgresPortForwardRemotePort := flag.String("postgres-port-forward-remote-port", "", "Remote postgres service port for port-forward (defaults to postgres port)") flag.Parse() + cfg, err := loadConfig(*configPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading config: %v\n", err) + os.Exit(1) + } + + if *sortBy == defaultSort && cfg.Sort != "" { + *sortBy = cfg.Sort + } + if *postgresSecret == defaultPostgresSecret && cfg.PostgresSecret != "" { + *postgresSecret = cfg.PostgresSecret + } + if *namespace == "" && cfg.Namespace != "" { + *namespace = cfg.Namespace + } + if *postgresHostOverride == "" && cfg.PostgresHost != "" { + *postgresHostOverride = cfg.PostgresHost + } + if *postgresPortOverride == "" && cfg.PostgresPort != "" { + *postgresPortOverride = cfg.PostgresPort + } + if *viewsFlag == defaultViews && cfg.Views != "" { + *viewsFlag = cfg.Views + } + if *hideFlag == "" && cfg.Hide != "" { + *hideFlag = cfg.Hide + } + if *filterName == "" && cfg.FilterName != "" { + *filterName = cfg.FilterName + } + if *filterTrait == "" && cfg.FilterTrait != "" { + *filterTrait = cfg.FilterTrait + } + if *hypervisorContext == "" && cfg.HypervisorContext != "" { + *hypervisorContext = cfg.HypervisorContext + } + if *hypervisorContexts == "" && len(cfg.HypervisorContexts) > 0 { + *hypervisorContexts = strings.Join(cfg.HypervisorContexts, ",") + } + if *reservationContext == "" && cfg.ReservationContext != "" { + *reservationContext = cfg.ReservationContext + } + if *reservationContexts == "" && len(cfg.ReservationContexts) > 0 { + *reservationContexts = strings.Join(cfg.ReservationContexts, ",") + } + if *postgresContext == "" && cfg.PostgresContext != "" { + *postgresContext = cfg.PostgresContext + } + if !*postgresPortForward && cfg.PostgresPortForward { + *postgresPortForward = true + } + if *postgresPortForwardService == "" && cfg.PostgresPortForwardService != "" { + *postgresPortForwardService = cfg.PostgresPortForwardService + } + if *postgresPortForwardLocalPort == "" && cfg.PostgresPortForwardLocal != "" { + *postgresPortForwardLocalPort = cfg.PostgresPortForwardLocal + } + if *postgresPortForwardRemotePort == "" && cfg.PostgresPortForwardRemote != "" { + *postgresPortForwardRemotePort = cfg.PostgresPortForwardRemote + } + views := parseViews(*viewsFlag) applyHideViews(views, *hideFlag) ctx := context.Background() - // Create kubernetes clients for hypervisors and reservations - // They may use different contexts if specified - hvClient, err := getClientForContext(*hypervisorContext) - if err != nil { - fmt.Fprintf(os.Stderr, "Error creating hypervisor client: %v\n", err) - os.Exit(1) + hypervisorContextList := append(splitCSV(*hypervisorContexts), splitCSV(*hypervisorContext)...) + if len(hypervisorContextList) == 0 { + hypervisorContextList = []string{""} } + hypervisorContextList = uniqueStrings(hypervisorContextList) - // Reuse the same client if contexts are the same, otherwise create a new one - var resClient client.Client - if *reservationContext == *hypervisorContext { - resClient = hvClient - } else { - resClient, err = getClientForContext(*reservationContext) + reservationContextList := append(splitCSV(*reservationContexts), splitCSV(*reservationContext)...) + if len(reservationContextList) == 0 { + reservationContextList = []string{""} + } + reservationContextList = uniqueStrings(reservationContextList) + showReservationContextInName := len(reservationContextList) > 1 + + // Create kubernetes clients for hypervisors and reservations + // They may use different contexts if specified + clientByContext := make(map[string]client.Client) + getOrCreateClient := func(contextName string) (client.Client, error) { + if c, ok := clientByContext[contextName]; ok { + return c, nil + } + c, err := getClientForContext(contextName) if err != nil { - fmt.Fprintf(os.Stderr, "Error creating reservation client: %v\n", err) - os.Exit(1) + return nil, err } + clientByContext[contextName] = c + return c, nil } // Create postgres client (for reading the secret) // This is typically the local cluster where cortex runs - var pgClient client.Client - switch *postgresContext { - case *hypervisorContext: - pgClient = hvClient - case *reservationContext: - pgClient = resClient - default: - pgClient, err = getClientForContext(*postgresContext) - if err != nil { - fmt.Fprintf(os.Stderr, "Error creating postgres client: %v\n", err) - os.Exit(1) - } + pgClient, err := getOrCreateClient(*postgresContext) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating postgres client: %v\n", err) + os.Exit(1) } // Determine namespace @@ -277,20 +494,54 @@ func main() { ns = "default" // Default fallback } - // Try to connect to postgres (use pgClient for reading the secret) + // Try to connect to postgres var db *sql.DB var serverMap map[string]serverInfo var flavorMap map[string]flavorInfo - db, serverMap, flavorMap = connectToPostgres(ctx, pgClient, *postgresSecret, ns, *postgresHostOverride, *postgresPortOverride, *postgresContext) + var pgCleanup func() + db, serverMap, flavorMap, pgCleanup = connectToPostgres( + ctx, + pgClient, + *postgresSecret, + ns, + *postgresHostOverride, + *postgresPortOverride, + *postgresContext, + *postgresPortForward, + *postgresPortForwardService, + *postgresPortForwardLocalPort, + *postgresPortForwardRemotePort, + ) if db != nil { defer db.Close() } + if pgCleanup != nil { + defer pgCleanup() + } - // Get all hypervisors to find all VMs (use hvClient) + // Get all hypervisors to find all VMs from all configured hypervisor contexts. var allHypervisors hv1.HypervisorList - if err := hvClient.List(ctx, &allHypervisors); err != nil { - fmt.Fprintf(os.Stderr, "Error listing hypervisors: %v\n", err) + hypervisorReadErrors := 0 + for _, hvCtx := range hypervisorContextList { + hvClient, err := getOrCreateClient(hvCtx) + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: Error creating hypervisor client for context %q: %v\n", contextDisplayName(hvCtx), err) + hypervisorReadErrors++ + continue + } + + var hypervisorsInContext hv1.HypervisorList + if err := hvClient.List(ctx, &hypervisorsInContext); err != nil { + fmt.Fprintf(os.Stderr, "Warning: Error listing hypervisors in context %q: %v\n", contextDisplayName(hvCtx), err) + hypervisorReadErrors++ + continue + } + + allHypervisors.Items = append(allHypervisors.Items, hypervisorsInContext.Items...) + } + if len(allHypervisors.Items) == 0 && hypervisorReadErrors > 0 { + fmt.Fprintf(os.Stderr, "Error: could not list hypervisors from any context (errors: %d)\n", hypervisorReadErrors) return } @@ -333,10 +584,32 @@ func main() { } } - // Get all reservations (both failover and committed) (use resClient) + // Get all reservations (both failover and committed) from all configured reservation contexts. var allReservations v1alpha1.ReservationList - if err := resClient.List(ctx, &allReservations); err != nil { - fmt.Fprintf(os.Stderr, "Error listing reservations: %v\n", err) + reservationReadErrors := 0 + for _, resCtx := range reservationContextList { + resClient, err := getOrCreateClient(resCtx) + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: Error creating reservation client for context %q: %v\n", contextDisplayName(resCtx), err) + reservationReadErrors++ + continue + } + + var reservationsInContext v1alpha1.ReservationList + if err := resClient.List(ctx, &reservationsInContext); err != nil { + fmt.Fprintf(os.Stderr, "Warning: Error listing reservations in context %q: %v\n", contextDisplayName(resCtx), err) + reservationReadErrors++ + continue + } + + for _, reservation := range reservationsInContext.Items { + resCopy := reservation + resCopy.Name = annotateReservationName(reservation.Name, resCtx, showReservationContextInName) + allReservations.Items = append(allReservations.Items, resCopy) + } + } + if len(allReservations.Items) == 0 { + fmt.Fprintf(os.Stderr, "Error: could not list reservations from any context (errors: %d)\n", reservationReadErrors) return } @@ -1010,20 +1283,32 @@ func main() { printHeader("Summary Statistics") // Kubernetes context information - hvCtx := *hypervisorContext - if hvCtx == "" { - hvCtx = "(current context)" - } - resCtx := *reservationContext - if resCtx == "" { - resCtx = "(current context)" - } pgCtx := *postgresContext if pgCtx == "" { pgCtx = "(current context)" } - fmt.Printf("Hypervisor context: %s\n", hvCtx) - fmt.Printf("Reservation context: %s\n", resCtx) + if len(hypervisorContextList) == 1 { + fmt.Printf("Hypervisor context: %s\n", contextDisplayName(hypervisorContextList[0])) + } else { + fmt.Printf("Hypervisor contexts: %s\n", strings.Join(func() []string { + display := make([]string, 0, len(hypervisorContextList)) + for _, c := range hypervisorContextList { + display = append(display, contextDisplayName(c)) + } + return display + }(), ", ")) + } + if len(reservationContextList) == 1 { + fmt.Printf("Reservation context: %s\n", contextDisplayName(reservationContextList[0])) + } else { + fmt.Printf("Reservation contexts: %s\n", strings.Join(func() []string { + display := make([]string, 0, len(reservationContextList)) + for _, c := range reservationContextList { + display = append(display, contextDisplayName(c)) + } + return display + }(), ", ")) + } fmt.Printf("Postgres context: %s\n", pgCtx) fmt.Println() @@ -1346,7 +1631,20 @@ func printHypervisorSummary(hypervisors []hv1.Hypervisor, reservations []v1alpha fmt.Println() } -func connectToPostgres(ctx context.Context, k8sClient client.Client, secretName, namespace, hostOverride, portOverride, contextName string) (db *sql.DB, serverMap map[string]serverInfo, flavorMap map[string]flavorInfo) { +func connectToPostgres( + ctx context.Context, + k8sClient client.Client, + secretName, + namespace, + hostOverride, + portOverride, + contextName string, + enablePortForward bool, + portForwardService, + portForwardLocalPort, + portForwardRemotePort string, +) (db *sql.DB, serverMap map[string]serverInfo, flavorMap map[string]flavorInfo, cleanup func()) { + ctxDisplay := contextName if ctxDisplay == "" { ctxDisplay = "(current context)" @@ -1362,7 +1660,7 @@ func connectToPostgres(ctx context.Context, k8sClient client.Client, secretName, fmt.Fprintf(os.Stderr, "Warning: Could not get postgres secret '%s' in namespace '%s' (context: %s): %v\n", secretName, namespace, ctxDisplay, err) fmt.Fprintf(os.Stderr, " Postgres features will be disabled.\n") fmt.Fprintf(os.Stderr, " Use --postgres-secret, --namespace, and --postgres-context flags to specify the secret location.\n\n") - return nil, nil, nil + return nil, nil, nil, nil } // Extract connection details @@ -1374,7 +1672,7 @@ func connectToPostgres(ctx context.Context, k8sClient client.Client, secretName, if user == "" || password == "" || database == "" { fmt.Fprintf(os.Stderr, "Warning: Postgres secret is missing required fields (user, password, database)\n") - return nil, nil, nil + return nil, nil, nil, nil } if port == "" { @@ -1389,17 +1687,50 @@ func connectToPostgres(ctx context.Context, k8sClient client.Client, secretName, password = strip(password) database = strip(database) - // Apply overrides if provided - if hostOverride != "" { - host = hostOverride - } - if portOverride != "" { - port = portOverride + if host == "" && !enablePortForward { + fmt.Fprintf(os.Stderr, "Warning: Postgres host is empty. Use --postgres-host to specify.\n") + return nil, nil, nil, nil } - if host == "" { - fmt.Fprintf(os.Stderr, "Warning: Postgres host is empty. Use --postgres-host to specify.\n") - return nil, nil, nil + // Port-forward uses the original secret host as the service name (before overrides) + if enablePortForward { + pfService := portForwardService + if pfService == "" { + pfService = host + } + + pfRemotePort := portForwardRemotePort + if pfRemotePort == "" { + pfRemotePort = port + } + + pfLocalPort := portForwardLocalPort + if pfLocalPort == "" { + if portOverride != "" { + pfLocalPort = portOverride + } else { + pfLocalPort = pfRemotePort + } + } + + fmt.Fprintf(os.Stderr, "Postgres: Starting port-forward %s localhost:%s->%s (context: %s)\n", pfService, pfLocalPort, pfRemotePort, ctxDisplay) + pfCleanup, err := startPortForward(namespace, contextName, pfService, pfLocalPort, pfRemotePort) + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: Failed to start postgres port-forward: %v\n", err) + fmt.Fprintf(os.Stderr, " Continuing without automatic port-forward.\n") + } else { + cleanup = pfCleanup + host = "localhost" + port = pfLocalPort + } + } else { + // Apply overrides only when not using port-forward + if hostOverride != "" { + host = hostOverride + } + if portOverride != "" { + port = portOverride + } } // Connect to postgres @@ -1409,7 +1740,10 @@ func connectToPostgres(ctx context.Context, k8sClient client.Client, secretName, db, err := sql.Open("postgres", connStr) if err != nil { fmt.Fprintf(os.Stderr, "Warning: Could not connect to postgres: %v\n", err) - return nil, nil, nil + if cleanup != nil { + cleanup() + } + return nil, nil, nil, nil } // Test connection @@ -1419,7 +1753,10 @@ func connectToPostgres(ctx context.Context, k8sClient client.Client, secretName, fmt.Fprintf(os.Stderr, " kubectl port-forward svc/%s %s:%s -n %s\n", host, port, port, namespace) fmt.Fprintf(os.Stderr, " ./visualize-reservations --postgres-host=localhost --postgres-port=%s\n\n", port) db.Close() - return nil, nil, nil + if cleanup != nil { + cleanup() + } + return nil, nil, nil, nil } // Query servers with host information @@ -1462,7 +1799,7 @@ func connectToPostgres(ctx context.Context, k8sClient client.Client, secretName, } } - return db, serverMap, flavorMap + return db, serverMap, flavorMap, cleanup } func printHeader(title string) {