Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions internal/cli/pdsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func newExecCmd(opts *Options) *cobra.Command {
var logDir string
var noPrefix bool
var fanout int
var readyOnly bool

cmd := &cobra.Command{
Use: "exec [session]",
Expand Down Expand Up @@ -65,7 +66,7 @@ func newExecCmd(opts *Options) *cobra.Command {
}

if multiPod || detach {
return runMultiPodExec(cmd, cc, commandArgs, podNames, role, labels, exclude, container, detach, timeout, logDir, noPrefix, fanout)
return runMultiPodExec(cmd, cc, commandArgs, podNames, role, labels, exclude, container, detach, timeout, logDir, noPrefix, fanout, readyOnly)
}

// Single-pod mode (existing behavior)
Expand Down Expand Up @@ -101,6 +102,7 @@ func newExecCmd(opts *Options) *cobra.Command {
cmd.Flags().StringVar(&logDir, "log-dir", "", "Write per-pod logs to directory")
cmd.Flags().BoolVar(&noPrefix, "no-prefix", false, "Suppress pod name prefix in output")
cmd.Flags().IntVar(&fanout, "fanout", pdshDefaultFanout, "Maximum concurrent pod executions")
cmd.Flags().BoolVar(&readyOnly, "ready-only", false, "Run only on pods that are already running (skip readiness check)")
return cmd
}

Expand Down Expand Up @@ -147,30 +149,52 @@ func validateMultiPodFlags(podNames []string, role string, labels []string, excl
return nil
}

func runMultiPodExec(cmd *cobra.Command, cc *commandContext, commandArgs []string, podNames []string, role string, labels []string, exclude []string, container string, detach bool, timeout time.Duration, logDir string, noPrefix bool, fanout int) error {
func runMultiPodExec(cmd *cobra.Command, cc *commandContext, commandArgs []string, podNames []string, role string, labels []string, exclude []string, container string, detach bool, timeout time.Duration, logDir string, noPrefix bool, fanout int, readyOnly bool) error {
ctx := cmd.Context()
labelSel := selectorForSessionRun(cc.sessionName)
sessionPods, err := cc.kube.ListPods(ctx, cc.namespace, false, labelSel)
if err != nil {
return fmt.Errorf("list session pods: %w", err)
}
pods := filterRunningPods(sessionPods)

// Apply user-specified filters before the readiness check so the
// running-vs-total comparison only considers targeted pods.
allPods := sessionPods
switch {
case len(podNames) > 0:
pods = filterPodsByName(pods, podNames)
allPods = filterPodsByName(allPods, podNames)
case role != "":
pods = filterPodsByRole(pods, role)
allPods = filterPodsByRole(allPods, role)
case len(labels) > 0:
pods = filterPodsByLabels(pods, labels)
allPods = filterPodsByLabels(allPods, labels)
}

if len(exclude) > 0 {
pods = excludePods(pods, exclude)
allPods = excludePods(allPods, exclude)
}

pods := filterRunningPods(allPods)

if len(allPods) == 0 {
return fmt.Errorf("no pods match the specified filters in session %q", cc.sessionName)
}

if len(pods) == 0 {
return fmt.Errorf("no running pods match the specified filters in session %q", cc.sessionName)
return fmt.Errorf("no running pods in session %q (0/%d pods ready)", cc.sessionName, len(allPods))
}

if len(pods) < len(allPods) && !readyOnly {
notReady := make([]string, 0, len(allPods)-len(pods))
runningSet := make(map[string]bool, len(pods))
for _, p := range pods {
runningSet[p.Name] = true
}
for _, p := range allPods {
if !runningSet[p.Name] {
notReady = append(notReady, fmt.Sprintf("%s (%s)", p.Name, p.Phase))
}
}
return fmt.Errorf("%d/%d pods are not running: %s\nUse --ready-only to run on the %d ready pods",
len(notReady), len(allPods), strings.Join(notReady, ", "), len(pods))
}

targetContainer := container
Expand Down Expand Up @@ -215,6 +239,8 @@ func runMultiExec(ctx context.Context, client connect.ExecClient, namespace stri
podNames[i] = p.Name
}
shortNames := shortPodNames(podNames)
colorEnabled := isInteractiveWriter(stdout)
displayPrefixes := formatPodPrefixes(shortNames, colorEnabled)

var writeMu sync.Mutex
noPrefixOut := &lockedWriter{w: stdout}
Expand All @@ -225,7 +251,7 @@ func runMultiExec(ctx context.Context, client connect.ExecClient, namespace stri
var wg sync.WaitGroup
for i, pod := range pods {
wg.Add(1)
go func(pod kube.PodSummary, shortName string) {
go func(pod kube.PodSummary, shortName, displayPrefix string) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
Expand All @@ -242,8 +268,8 @@ func runMultiExec(ctx context.Context, client connect.ExecClient, namespace stri
podStdout = noPrefixOut
podStderr = noPrefixErr
} else {
pw := newPrefixedWriter(shortName, stdout, &writeMu)
pe := newPrefixedWriter(shortName, stderr, &writeMu)
pw := newPrefixedWriter(displayPrefix, stdout, &writeMu)
pe := newPrefixedWriter(displayPrefix, stderr, &writeMu)
defer pw.Flush()
defer pe.Flush()
podStdout = pw
Expand All @@ -264,7 +290,7 @@ func runMultiExec(ctx context.Context, client connect.ExecClient, namespace stri

err := connect.RunOnContainer(execCtx, client, namespace, pod.Name, container, command, false, nil, podStdout, podStderr)
results <- podExecResult{pod: pod.Name, err: err}
}(pod, shortNames[i])
}(pod, shortNames[i], displayPrefixes[i])
}

go func() {
Expand Down Expand Up @@ -305,6 +331,8 @@ func runDetachExec(ctx context.Context, client connect.ExecClient, namespace str
podNames[i] = p.Name
}
shortNames := shortPodNames(podNames)
colorEnabled := isInteractiveWriter(out)
displayPrefixes := formatPodPrefixes(shortNames, colorEnabled)
command := detachCommand(cmdStr)

results := make(chan podExecResult, len(pods))
Expand All @@ -327,19 +355,21 @@ func runDetachExec(ctx context.Context, client connect.ExecClient, namespace str
close(results)
}()

displayMap := make(map[string]string, len(podNames))
nameMap := make(map[string]string, len(podNames))
for i, n := range podNames {
displayMap[n] = displayPrefixes[i]
nameMap[n] = shortNames[i]
}

var failCount int
for r := range results {
short := nameMap[r.pod]
prefix := displayMap[r.pod]
if r.err != nil {
fmt.Fprintf(out, "%s: error: %v\n", short, r.err)
fmt.Fprintf(out, "%s error: %v\n", prefix, r.err)
failCount++
} else {
fmt.Fprintf(out, "%s: detached\n", short)
fmt.Fprintf(out, "%s detached\n", prefix)
}
}
if failCount > 0 {
Expand Down
14 changes: 7 additions & 7 deletions internal/cli/pdsh_fanout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,10 @@ func TestMultiExecPipelineAllPodsWithPrefixedOutput(t *testing.T) {

// Step 4: verify prefixed output
output := stdout.String()
if !strings.Contains(output, "worker-0: gpu0: NVIDIA A100") {
if !strings.Contains(output, "[worker-0] gpu0: NVIDIA A100") {
t.Fatalf("expected prefixed output for worker-0, got %q", output)
}
if !strings.Contains(output, "worker-1: gpu1: NVIDIA H100") {
if !strings.Contains(output, "[worker-1] gpu1: NVIDIA H100") {
t.Fatalf("expected prefixed output for worker-1, got %q", output)
}
}
Expand Down Expand Up @@ -327,13 +327,13 @@ func TestMultiExecPipelinePartialFailureWithDetach(t *testing.T) {
}

output := stdout.String()
if !strings.Contains(output, "worker-0: detached") {
if !strings.Contains(output, "worker-0] detached") {
t.Fatalf("expected worker-0 detached, got %q", output)
}
if !strings.Contains(output, "worker-1: error: container not ready") {
if !strings.Contains(output, "worker-1] error: container not ready") {
t.Fatalf("expected worker-1 error, got %q", output)
}
if !strings.Contains(output, "worker-2: detached") {
if !strings.Contains(output, "worker-2] detached") {
t.Fatalf("expected worker-2 detached, got %q", output)
}
}
Expand Down Expand Up @@ -361,7 +361,7 @@ func TestMultiExecPipelineNoPrefixModeMultiplePods(t *testing.T) {

output := stdout.String()
// In no-prefix mode, output should NOT contain the pod name prefix.
if strings.Contains(output, "worker-0:") || strings.Contains(output, "worker-1:") {
if strings.Contains(output, "[worker-0]") || strings.Contains(output, "[worker-1]") {
t.Fatalf("expected no prefix in output, got %q", output)
}
if !strings.Contains(output, "alpha") || !strings.Contains(output, "bravo") {
Expand Down Expand Up @@ -526,7 +526,7 @@ func TestMultiExecFullPipelineRoleFilterFanoutLogDir(t *testing.T) {

// Verify prefixed output present.
output := stdout.String()
if !strings.Contains(output, "worker-0:") || !strings.Contains(output, "worker-1:") || !strings.Contains(output, "worker-3:") {
if !strings.Contains(output, "[worker-0]") || !strings.Contains(output, "[worker-1]") || !strings.Contains(output, "[worker-3]") {
t.Fatalf("expected prefixed output for workers 0,1,3, got %q", output)
}
}
28 changes: 24 additions & 4 deletions internal/cli/pdsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func TestRunDetachExec(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(stdout.String(), "worker-0: detached") || !strings.Contains(stdout.String(), "worker-1: detached") {
if !strings.Contains(stdout.String(), "worker-0] detached") || !strings.Contains(stdout.String(), "worker-1] detached") {
t.Fatalf("expected detach confirmation, got %q", stdout.String())
}
}
Expand All @@ -193,10 +193,10 @@ func TestRunDetachExecWithError(t *testing.T) {
if err == nil {
t.Fatal("expected error for partial failure")
}
if !strings.Contains(stdout.String(), "worker-0: detached") {
if !strings.Contains(stdout.String(), "worker-0] detached") {
t.Fatalf("expected detach for worker-0, got %q", stdout.String())
}
if !strings.Contains(stdout.String(), "worker-1: error:") {
if !strings.Contains(stdout.String(), "worker-1] error:") {
t.Fatalf("expected error for worker-1, got %q", stdout.String())
}
}
Expand All @@ -218,7 +218,7 @@ func TestRunMultiExecSuccess(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(stdout.String(), "worker-0: ok") || !strings.Contains(stdout.String(), "worker-1: ok") {
if !strings.Contains(stdout.String(), "worker-0] ok") || !strings.Contains(stdout.String(), "worker-1] ok") {
t.Fatalf("expected prefixed output, got %q", stdout.String())
}
}
Expand Down Expand Up @@ -373,3 +373,23 @@ func TestValidateMultiPodFlags(t *testing.T) {
})
}
}

func TestFilterRunningPodsReadinessCheck(t *testing.T) {
allPods := []kube.PodSummary{
{Name: "leader-0", Phase: "Running"},
{Name: "worker-0", Phase: "Pending"},
{Name: "worker-1", Phase: "ContainerCreating"},
}
running := filterRunningPods(allPods)

// Without --ready-only, should detect the gap.
if len(running) == len(allPods) {
t.Fatal("expected fewer running pods than total")
}
if len(running) != 1 {
t.Fatalf("expected 1 running pod, got %d", len(running))
}
if running[0].Name != "leader-0" {
t.Fatalf("expected leader-0, got %s", running[0].Name)
}
}
45 changes: 43 additions & 2 deletions internal/cli/pdsh_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,47 @@ func shortPodNames(names []string) []string {
return out
}

func maxPrefixWidth(names []string) int {
w := 0
for _, n := range names {
if len(n) > w {
w = len(n)
}
}
return w
}

var podPrefixColors = []string{
"\033[36m", // cyan
"\033[33m", // yellow
"\033[32m", // green
"\033[35m", // magenta
"\033[34m", // blue
"\033[91m", // bright red
"\033[96m", // bright cyan
"\033[93m", // bright yellow
}

func podPrefixColor(index int) string {
return podPrefixColors[index%len(podPrefixColors)]
}

const prefixReset = "\033[0m"

func formatPodPrefixes(shortNames []string, color bool) []string {
width := maxPrefixWidth(shortNames)
out := make([]string, len(shortNames))
for i, name := range shortNames {
padded := fmt.Sprintf("%-*s", width, name)
if color {
out[i] = podPrefixColor(i) + "[" + padded + "]" + prefixReset
} else {
out[i] = "[" + padded + "]"
}
}
return out
}

// prefixedWriter is a thread-safe io.Writer that buffers input and emits
// complete lines prefixed with the pod short name. Partial lines are held
// until a newline arrives or Flush is called.
Expand All @@ -81,7 +122,7 @@ func (w *prefixedWriter) Write(p []byte) (int, error) {
break
}
w.mu.Lock()
fmt.Fprintf(w.dest, "%s: %s", w.prefix, line)
fmt.Fprintf(w.dest, "%s %s", w.prefix, line)
w.mu.Unlock()
}
return len(p), nil
Expand All @@ -93,7 +134,7 @@ func (w *prefixedWriter) Flush() {
return
}
w.mu.Lock()
fmt.Fprintf(w.dest, "%s: %s\n", w.prefix, w.buf.String())
fmt.Fprintf(w.dest, "%s %s\n", w.prefix, w.buf.String())
w.mu.Unlock()
w.buf.Reset()
}
Loading
Loading