Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,9 @@ func (cmd *Command) appendFlag(fl Flag) {

// VisiblePersistentFlags returns a slice of [LocalFlag] with Persistent=true and Hidden=false.
func (cmd *Command) VisiblePersistentFlags() []Flag {
if cmd.isCompletionCommand {
return nil
}
var flags []Flag
for _, fl := range cmd.Root().Flags {
pfl, ok := fl.(LocalFlag)
Expand Down
6 changes: 0 additions & 6 deletions command_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,6 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context
var rargs Args = &stringSliceArgs{v: osArgs}
var args Args = &stringSliceArgs{rargs.Tail()}

if cmd.isCompletionCommand {
tracef("completion command detected, skipping pre-parse (cmd=%[1]q)", cmd.Name)
cmd.parsedArgs = args
return ctx, cmd.Action(ctx, cmd)
}

for _, f := range cmd.allFlags() {
if err := f.PreParse(); err != nil {
return ctx, err
Expand Down
60 changes: 25 additions & 35 deletions completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"embed"
"fmt"
"sort"
"strings"
)

Expand Down Expand Up @@ -58,45 +57,36 @@ Output the script to path/to/autocomplete/$COMMAND.ps1 an run it.
`

func buildCompletionCommand(appName string) *Command {
return &Command{
Name: completionCommandName,
Hidden: true,
Usage: "Output shell completion script for bash, zsh, fish, or Powershell",
Description: strings.ReplaceAll(completionDescription, "$COMMAND", appName),
Action: func(ctx context.Context, cmd *Command) error {
return printShellCompletion(ctx, cmd, appName)
},
cmd := &Command{
Name: completionCommandName,
Hidden: true,
Usage: "Output shell completion script for bash, zsh, fish, or Powershell",
Description: strings.ReplaceAll(completionDescription, "$COMMAND", appName),
isCompletionCommand: true,
}
}

func printShellCompletion(_ context.Context, cmd *Command, appName string) error {
var shells []string
for k := range shellCompletions {
shells = append(shells, k)
}

sort.Strings(shells)

if cmd.Args().Len() == 0 {
return Exit(fmt.Sprintf("no shell provided for completion command. available shells are %+v", shells), 1)
}
s := cmd.Args().First()

renderCompletion, ok := shellCompletions[s]
if !ok {
return Exit(fmt.Sprintf("unknown shell %s, available shells are %+v", s, shells), 1)
for shell, render := range shellCompletions {
cmd.Commands = append(cmd.Commands, buildShellCompletionSubcommand(shell, render, appName))
}

completionScript, err := renderCompletion(cmd, appName)
if err != nil {
return Exit(err, 1)
}
return cmd
}

_, err = cmd.Writer.Write([]byte(completionScript))
if err != nil {
return Exit(err, 1)
func buildShellCompletionSubcommand(shell string, render renderCompletion, appName string) *Command {
return &Command{
Name: shell,
Usage: fmt.Sprintf("Output %s completion script", shell),
isCompletionCommand: true,
Action: func(ctx context.Context, cmd *Command) error {
completionScript, err := render(cmd, appName)
if err != nil {
return Exit(err, 1)
}
_, err = cmd.Root().Writer.Write([]byte(completionScript))
if err != nil {
return Exit(err, 1)
}
return nil
},
}

return nil
}
124 changes: 96 additions & 28 deletions completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,53 @@ import (
"github.com/stretchr/testify/require"
)

func TestCompletionHelp(t *testing.T) {
tests := []struct {
name string
args []string
}{
{
name: "short help flag",
args: []string{"foo", completionCommandName, "-h"},
},
{
name: "long help flag",
args: []string{"foo", completionCommandName, "--help"},
},
{
name: "completion bash short help flag",
args: []string{"foo", completionCommandName, "bash", "-h"},
},
{
name: "completion bash long help flag",
args: []string{"foo", completionCommandName, "bash", "--help"},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
out := &bytes.Buffer{}

cmd := &Command{
EnableShellCompletion: true,
Writer: out,
Flags: []Flag{
&StringFlag{
Name: "required-flag",
Required: true,
},
},
}

r := require.New(t)

r.NoError(cmd.Run(buildTestContext(t), test.args))
r.Contains(out.String(), "USAGE")
r.NotContains(out.String(), "GLOBAL OPTIONS")
})
}
}

func TestCompletionDisable(t *testing.T) {
cmd := &Command{}

Expand All @@ -19,8 +66,11 @@ func TestCompletionDisable(t *testing.T) {
}

func TestCompletionEnable(t *testing.T) {
out := &bytes.Buffer{}

cmd := &Command{
EnableShellCompletion: true,
Writer: out,
Flags: []Flag{
&StringFlag{
Name: "goo",
Expand All @@ -29,18 +79,23 @@ func TestCompletionEnable(t *testing.T) {
},
}

err := cmd.Run(buildTestContext(t), []string{"foo", completionCommandName})
assert.ErrorContains(t, err, "no shell provided")
r := require.New(t)
r.NoError(cmd.Run(buildTestContext(t), []string{"foo", completionCommandName}))
r.Contains(out.String(), "USAGE")
}

func TestCompletionEnableDiffCommandName(t *testing.T) {
out := &bytes.Buffer{}

cmd := &Command{
EnableShellCompletion: true,
ShellCompletionCommandName: "junky",
Writer: out,
}

err := cmd.Run(buildTestContext(t), []string{"foo", "junky"})
assert.ErrorContains(t, err, "no shell provided")
r := require.New(t)
r.NoError(cmd.Run(buildTestContext(t), []string{"foo", "junky"}))
r.Contains(out.String(), "USAGE")
}

func TestCompletionShell(t *testing.T) {
Expand All @@ -56,10 +111,7 @@ func TestCompletionShell(t *testing.T) {
r := require.New(t)

r.NoError(cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, k}))
r.Containsf(
k, out.String(),
"Expected output to contain shell name %[1]q", k,
)
r.NotEmpty(out.String(), "Expected non-empty completion output for shell %q", k)
})
}
}
Expand Down Expand Up @@ -255,25 +307,18 @@ func TestCompletionSubcommand(t *testing.T) {
}
}

type mockWriter struct {
err error
}

func (mw *mockWriter) Write(p []byte) (int, error) {
if mw.err != nil {
return 0, mw.err
}
return len(p), nil
}

func TestCompletionInvalidShell(t *testing.T) {
cmd := &Command{
EnableShellCompletion: true,
}

unknownShellName := "junky-sheell"
err := cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, unknownShellName})
assert.ErrorContains(t, err, "unknown shell junky-sheell")
assert.ErrorContains(t, err, fmt.Sprintf("No help topic for '%s'", unknownShellName))
Comment thread
dearchap marked this conversation as resolved.
}

func TestCompletionShellRenderError(t *testing.T) {
unknownShellName := "junky-sheell"

enableError := true
shellCompletions[unknownShellName] = func(c *Command, appName string) (string, error) {
Expand All @@ -286,16 +331,39 @@ func TestCompletionInvalidShell(t *testing.T) {
delete(shellCompletions, unknownShellName)
}()

err = cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, unknownShellName})
cmd := &Command{
EnableShellCompletion: true,
}

err := cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, unknownShellName})
assert.ErrorContains(t, err, "cant do completion")
}

// now disable shell completion error
enableError = false
c := cmd.Command(completionCommandName)
assert.NotNil(t, c)
c.Writer = &mockWriter{
err: fmt.Errorf("writer error"),
type mockWriter struct {
err error
}

func (mw *mockWriter) Write(p []byte) (int, error) {
if mw.err != nil {
return 0, mw.err
}
return len(p), nil
}

func TestCompletionShellWriteError(t *testing.T) {
shellName := "mock-shell"
shellCompletions[shellName] = func(c *Command, appName string) (string, error) {
return "something", nil
}
err = cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, unknownShellName})
defer func() {
delete(shellCompletions, shellName)
}()

cmd := &Command{
EnableShellCompletion: true,
Writer: &mockWriter{err: fmt.Errorf("writer error")},
}

err := cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, shellName})
assert.ErrorContains(t, err, "writer error")
}
Loading