Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ type Command struct {
// An action to execute after any subcommands are run, but after the subcommand has finished
// It is run even if Action() panics
After AfterFunc `json:"-"`
// An action to validate arguments before the command is run. If non-nil, it
// is called before Before and Action. If the current command does not set
// ArgValidator, the nearest ancestor that does is used instead.
// Returning a non-nil error short-circuits the command.
ArgValidator ArgValidatorFunc `json:"-"`
// The function to call when this command is invoked
Action ActionFunc `json:"-"`
// Execute this function if the proper command cannot be found
Expand Down
17 changes: 17 additions & 0 deletions command_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,14 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context
// First, resolve the chain of nested commands up to the parent.
cmdChain := commandChain(cmd)

// Run ArgValidator from the nearest ancestor that sets one.
if validator := findArgValidator(cmd); validator != nil {
if err := validator(ctx, cmd); err != nil {
deferErr = cmd.handleExitCoder(ctx, err)
return ctx, deferErr
}
}

// Run Before actions in order.
if ctx, err = runBefore(ctx, cmdChain); err != nil {
deferErr = err
Expand Down Expand Up @@ -397,6 +405,15 @@ func commandChain(cmd *Command) []*Command {
return cmdChain
}

func findArgValidator(cmd *Command) ArgValidatorFunc {
for c := cmd; c != nil; c = c.parent {
if c.ArgValidator != nil {
return c.ArgValidator
}
}
return nil
}

func runBefore(ctx context.Context, cmdChain []*Command) (context.Context, error) {
for _, cmd := range cmdChain {
if cmd.Before == nil {
Expand Down
115 changes: 115 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6450,3 +6450,118 @@ func TestCommand_Walk_NilFn(t *testing.T) {
cmd := &Command{Name: "foo"}
assert.Nil(t, cmd.Walk(nil))
}

func TestCommand_ArgValidator_RunsBeforeAction(t *testing.T) {
var validated bool
var actionRan bool

cmd := &Command{
Name: "test",
ArgValidator: func(_ context.Context, _ *Command) error {
validated = true
return nil
},
Action: func(_ context.Context, _ *Command) error {
actionRan = true
return nil
},
}

err := cmd.Run(buildTestContext(t), []string{"test"})
require.NoError(t, err)
assert.True(t, validated)
assert.True(t, actionRan)
}

func TestCommand_ArgValidator_ErrorShortCircuitsAction(t *testing.T) {
var actionRan bool

cmd := &Command{
Name: "test",
ArgValidator: func(_ context.Context, _ *Command) error {
return fmt.Errorf("validation failed")
},
Action: func(_ context.Context, _ *Command) error {
actionRan = true
return nil
},
}

err := cmd.Run(buildTestContext(t), []string{"test"})
assert.ErrorContains(t, err, "validation failed")
assert.False(t, actionRan)
}

func TestCommand_ArgValidator_InheritsFromParent(t *testing.T) {
var validated bool

root := &Command{
Name: "root",
ArgValidator: func(_ context.Context, _ *Command) error {
validated = true
return nil
},
Commands: []*Command{
{
Name: "sub",
Action: func(_ context.Context, _ *Command) error { return nil },
},
},
}

err := root.Run(buildTestContext(t), []string{"root", "sub"})
require.NoError(t, err)
assert.True(t, validated)
}

func TestCommand_ArgValidator_SubcommandOverride(t *testing.T) {
var parentValidated bool
var childValidated bool

root := &Command{
Name: "root",
ArgValidator: func(_ context.Context, _ *Command) error {
parentValidated = true
return nil
},
Commands: []*Command{
{
Name: "sub",
ArgValidator: func(_ context.Context, _ *Command) error {
childValidated = true
return nil
},
Action: func(_ context.Context, _ *Command) error { return nil },
},
},
}

err := root.Run(buildTestContext(t), []string{"root", "sub"})
require.NoError(t, err)
assert.False(t, parentValidated, "should use child's validator, not parent's")
assert.True(t, childValidated)
}

func TestCommand_ArgValidator_RunsBeforeBefore(t *testing.T) {
var order []string

cmd := &Command{
Name: "test",
ArgValidator: func(_ context.Context, _ *Command) error {
order = append(order, "validator")
return nil
},
Before: func(_ context.Context, _ *Command) (context.Context, error) {
order = append(order, "before")
return nil, nil
},
Action: func(_ context.Context, _ *Command) error {
order = append(order, "action")
return nil
},
}

err := cmd.Run(buildTestContext(t), []string{"test"})
require.NoError(t, err)
assert.Equal(t, []string{"validator", "before", "action"}, order)
}
7 changes: 7 additions & 0 deletions funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ type AfterFunc func(context.Context, *Command) error
// ActionFunc is the action to execute when no subcommands are specified
type ActionFunc func(context.Context, *Command) error

// ArgValidatorFunc is an action to validate arguments before the command is run.
// If non-nil, it is called before the command's After and Action functions.
// Returning a non-nil error short-circuits the command and propagates as
// the exit error. If the current command does not set ArgValidator, the
// nearest ancestor that does is used instead.
type ArgValidatorFunc func(context.Context, *Command) error

// CommandNotFoundFunc is executed if the proper command cannot be found
type CommandNotFoundFunc func(context.Context, *Command, string)

Expand Down
12 changes: 12 additions & 0 deletions godoc-current.txt
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ type AfterFunc func(context.Context, *Command) error
AfterFunc is an action that executes after any subcommands are run and have
finished. The AfterFunc is run even if Action() panics.

type ArgValidatorFunc func(context.Context, *Command) error
ArgValidatorFunc is an action to validate arguments before the command
is run. If non-nil, it is called before the command's After and Action
functions. Returning a non-nil error short-circuits the command and
propagates as the exit error. If the current command does not set
ArgValidator, the nearest ancestor that does is used instead.

type Args interface {
// Get returns the nth argument, or else a blank string
Get(n int) string
Expand Down Expand Up @@ -482,6 +489,11 @@ type Command struct {
// An action to execute after any subcommands are run, but after the subcommand has finished
// It is run even if Action() panics
After AfterFunc `json:"-"`
// An action to validate arguments before the command is run. If non-nil, it
// is called before Before and Action. If the current command does not set
// ArgValidator, the nearest ancestor that does is used instead.
// Returning a non-nil error short-circuits the command.
ArgValidator ArgValidatorFunc `json:"-"`
// The function to call when this command is invoked
Action ActionFunc `json:"-"`
// Execute this function if the proper command cannot be found
Expand Down
12 changes: 12 additions & 0 deletions testdata/godoc-v3.x.txt
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ type AfterFunc func(context.Context, *Command) error
AfterFunc is an action that executes after any subcommands are run and have
finished. The AfterFunc is run even if Action() panics.

type ArgValidatorFunc func(context.Context, *Command) error
ArgValidatorFunc is an action to validate arguments before the command
is run. If non-nil, it is called before the command's After and Action
functions. Returning a non-nil error short-circuits the command and
propagates as the exit error. If the current command does not set
ArgValidator, the nearest ancestor that does is used instead.

type Args interface {
// Get returns the nth argument, or else a blank string
Get(n int) string
Expand Down Expand Up @@ -482,6 +489,11 @@ type Command struct {
// An action to execute after any subcommands are run, but after the subcommand has finished
// It is run even if Action() panics
After AfterFunc `json:"-"`
// An action to validate arguments before the command is run. If non-nil, it
// is called before Before and Action. If the current command does not set
// ArgValidator, the nearest ancestor that does is used instead.
// Returning a non-nil error short-circuits the command.
ArgValidator ArgValidatorFunc `json:"-"`
// The function to call when this command is invoked
Action ActionFunc `json:"-"`
// Execute this function if the proper command cannot be found
Expand Down