diff --git a/mage/args_test.go b/mage/args_test.go index d8ef8abd..42616d7f 100644 --- a/mage/args_test.go +++ b/mage/args_test.go @@ -179,6 +179,115 @@ Aliases: speak } } +func TestVariadicNoArgs(t *testing.T) { + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + inv := Invocation{ + Dir: "./testdata/args", + Stderr: stderr, + Stdout: stdout, + Args: []string{"greet"}, + } + code := Invoke(inv) + if code != 0 { + t.Log("stderr:", stderr) + t.Log("stdout:", stdout) + t.Fatalf("expected code 0, but got %v", code) + } + actual := stdout.String() + if actual != "" { + t.Fatalf("expected empty output, got %q", actual) + } +} + +func TestVariadicOneArg(t *testing.T) { + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + inv := Invocation{ + Dir: "./testdata/args", + Stderr: stderr, + Stdout: stdout, + Args: []string{"greet", "Alice"}, + } + code := Invoke(inv) + if code != 0 { + t.Log("stderr:", stderr) + t.Log("stdout:", stdout) + t.Fatalf("expected code 0, but got %v", code) + } + actual := stdout.String() + expected := "Hello, Alice\n" + if actual != expected { + t.Fatalf("output is not expected:\n%q", actual) + } +} + +func TestVariadicMultipleArgs(t *testing.T) { + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + inv := Invocation{ + Dir: "./testdata/args", + Stderr: stderr, + Stdout: stdout, + Args: []string{"greet", "Alice", "Bob"}, + } + code := Invoke(inv) + if code != 0 { + t.Log("stderr:", stderr) + t.Log("stdout:", stdout) + t.Fatalf("expected code 0, but got %v", code) + } + actual := stdout.String() + expected := "Hello, Alice\nHello, Bob\n" + if actual != expected { + t.Fatalf("output is not expected:\n%q", actual) + } +} + +func TestVariadicMissingRequired(t *testing.T) { + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + inv := Invocation{ + Dir: "./testdata/args", + Stderr: stderr, + Stdout: stdout, + Args: []string{"tag"}, + } + code := Invoke(inv) + if code != 2 { + t.Log("stderr:", stderr) + t.Log("stdout:", stdout) + t.Fatalf("expected code 2, but got %v", code) + } + actual := stderr.String() + expected := "not enough arguments for target \"Tag\", expected at least 1, got 0\n" + if actual != expected { + t.Fatalf("output is not expected:\n%q", actual) + } +} + +func TestVariadicWithRequired(t *testing.T) { + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + inv := Invocation{ + Dir: "./testdata/args", + Stderr: stderr, + Stdout: stdout, + Args: []string{"tag", "myimage", "v1", "latest"}, + } + code := Invoke(inv) + if code != 0 { + t.Log("stderr:", stderr) + t.Log("stdout:", stdout) + t.Fatalf("expected code 0, but got %v", code) + } + actual := stdout.String() + expected := "myimage\nv1\nlatest\n" + if actual != expected { + t.Fatalf("output is not expected:\n%q", actual) + } +} + func TestMgF(t *testing.T) { stderr := &bytes.Buffer{} stdout := &bytes.Buffer{} diff --git a/mage/template.go b/mage/template.go index b3466eaa..b5c0cebb 100644 --- a/mage/template.go +++ b/mage/template.go @@ -450,11 +450,44 @@ Options: switch _strings.ToLower(target) { {{range .Funcs }} case "{{lower .TargetName}}": + {{- if .IsVariadic}} + expected := x + {{len .Args}} - 1 + {{- else}} expected := x + {{len .Args}} + {{- end}} if expected > len(args.Args) { // note that expected and args at this point include the arg for the target itself // so we subtract 1 here to show the number of args without the target. + {{- if .IsVariadic}} + logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected at least %v, got %v\n", expected-1, len(args.Args)-1) + {{- else}} logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected %v, got %v\n", expected-1, len(args.Args)-1) + {{- end}} + os.Exit(2) + } + if args.Verbose { + logger.Println("Running target:", "{{.TargetName}}") + } + {{.ExecCode}} + handleError(logger, ret) + {{- end}} + {{range .Imports}} + {{$imp := .}} + {{range .Info.Funcs }} + case "{{lower .TargetName}}": + {{- if .IsVariadic}} + expected := x + {{len .Args}} - 1 + {{- else}} + expected := x + {{len .Args}} + {{- end}} + if expected > len(args.Args) { + // note that expected and args at this point include the arg for the target itself + // so we subtract 1 here to show the number of args without the target. + {{- if .IsVariadic}} + logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected at least %v, got %v\n", expected-1, len(args.Args)-1) + {{- else}} + logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected %v, got %v\n", expected-1, len(args.Args)-1) + {{- end}} os.Exit(2) } if args.Verbose { @@ -463,23 +496,6 @@ Options: {{.ExecCode}} handleError(logger, ret) {{- end}} - {{range .Imports}} - {{$imp := .}} - {{range .Info.Funcs }} - case "{{lower .TargetName}}": - expected := x + {{len .Args}} - if expected > len(args.Args) { - // note that expected and args at this point include the arg for the target itself - // so we subtract 1 here to show the number of args without the target. - logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected %v, got %v\n", expected-1, len(args.Args)-1) - os.Exit(2) - } - if args.Verbose { - logger.Println("Running target:", "{{.TargetName}}") - } - {{.ExecCode}} - handleError(logger, ret) - {{- end}} {{- end}} default: logger.Printf("Unknown target specified: %q\n", target) diff --git a/mage/testdata/args/magefile.go b/mage/testdata/args/magefile.go index b74483c5..e8bbb8d0 100644 --- a/mage/testdata/args/magefile.go +++ b/mage/testdata/args/magefile.go @@ -54,3 +54,19 @@ func HasDep() { func DoubleIt(f float64) { fmt.Printf("%.1f * 2 = %.1f\n", f, f*2) } + +// Greet says hello to each provided name. +func Greet(names ...string) { + for _, n := range names { + fmt.Println("Hello,", n) + } +} + +// Tag prints an image name with optional extra labels. +func Tag(image string, labels ...string) error { + fmt.Println(image) + for _, l := range labels { + fmt.Println(l) + } + return nil +} diff --git a/parse/parse.go b/parse/parse.go index 8d67bb13..5b3ed606 100644 --- a/parse/parse.go +++ b/parse/parse.go @@ -51,6 +51,7 @@ type Function struct { Synopsis string Comment string Args []Arg + IsVariadic bool } var _ sort.Interface = (Functions)(nil) @@ -116,43 +117,101 @@ func (f Function) ExecCode() string { var parseargs string for x, arg := range f.Args { - switch arg.Type { - case "string": - parseargs += fmt.Sprintf(` - arg%d := args.Args[x] - x++`, x) - case "int": - parseargs += fmt.Sprintf(` - arg%d, err := strconv.Atoi(args.Args[x]) + isLastVariadic := f.IsVariadic && x == len(f.Args)-1 + if isLastVariadic { + switch arg.Type { + case "string": + parseargs += fmt.Sprintf(` + arg%d := args.Args[x:] + x = len(args.Args)`, x) + case "int": + parseargs += fmt.Sprintf(` + var arg%d []int + for _, _s := range args.Args[x:] { + _v, err := strconv.Atoi(_s) if err != nil { - logger.Printf("can't convert argument %%q to int\n", args.Args[x]) + logger.Printf("can't convert argument %%q to int\n", _s) os.Exit(2) } - x++`, x) - case "float64": - parseargs += fmt.Sprintf(` - arg%d, err := strconv.ParseFloat(args.Args[x], 64) + arg%d = append(arg%d, _v) + } + x = len(args.Args)`, x, x, x) + case "float64": + parseargs += fmt.Sprintf(` + var arg%d []float64 + for _, _s := range args.Args[x:] { + _v, err := strconv.ParseFloat(_s, 64) if err != nil { - logger.Printf("can't convert argument %%q to float64\n", args.Args[x]) + logger.Printf("can't convert argument %%q to float64\n", _s) os.Exit(2) } - x++`, x) - case "bool": - parseargs += fmt.Sprintf(` - arg%d, err := strconv.ParseBool(args.Args[x]) + arg%d = append(arg%d, _v) + } + x = len(args.Args)`, x, x, x) + case "bool": + parseargs += fmt.Sprintf(` + var arg%d []bool + for _, _s := range args.Args[x:] { + _v, err := strconv.ParseBool(_s) if err != nil { - logger.Printf("can't convert argument %%q to bool\n", args.Args[x]) + logger.Printf("can't convert argument %%q to bool\n", _s) os.Exit(2) } - x++`, x) - case "time.Duration": - parseargs += fmt.Sprintf(` - arg%d, err := time.ParseDuration(args.Args[x]) + arg%d = append(arg%d, _v) + } + x = len(args.Args)`, x, x, x) + case "time.Duration": + parseargs += fmt.Sprintf(` + var arg%d []time.Duration + for _, _s := range args.Args[x:] { + _v, err := time.ParseDuration(_s) if err != nil { - logger.Printf("can't convert argument %%q to time.Duration\n", args.Args[x]) + logger.Printf("can't convert argument %%q to time.Duration\n", _s) os.Exit(2) } - x++`, x) + arg%d = append(arg%d, _v) + } + x = len(args.Args)`, x, x, x) + } + } else { + switch arg.Type { + case "string": + parseargs += fmt.Sprintf(` + arg%d := args.Args[x] + x++`, x) + case "int": + parseargs += fmt.Sprintf(` + arg%d, err := strconv.Atoi(args.Args[x]) + if err != nil { + logger.Printf("can't convert argument %%q to int\n", args.Args[x]) + os.Exit(2) + } + x++`, x) + case "float64": + parseargs += fmt.Sprintf(` + arg%d, err := strconv.ParseFloat(args.Args[x], 64) + if err != nil { + logger.Printf("can't convert argument %%q to float64\n", args.Args[x]) + os.Exit(2) + } + x++`, x) + case "bool": + parseargs += fmt.Sprintf(` + arg%d, err := strconv.ParseBool(args.Args[x]) + if err != nil { + logger.Printf("can't convert argument %%q to bool\n", args.Args[x]) + os.Exit(2) + } + x++`, x) + case "time.Duration": + parseargs += fmt.Sprintf(` + arg%d, err := time.ParseDuration(args.Args[x]) + if err != nil { + logger.Printf("can't convert argument %%q to time.Duration\n", args.Args[x]) + os.Exit(2) + } + x++`, x) + } } } @@ -168,7 +227,11 @@ func (f Function) ExecCode() string { args = append(args, "ctx") } for x := 0; x < len(f.Args); x++ { - args = append(args, fmt.Sprintf("arg%d", x)) + if f.IsVariadic && x == len(f.Args)-1 { + args = append(args, fmt.Sprintf("arg%d...", x)) + } else { + args = append(args, fmt.Sprintf("arg%d", x)) + } } out += strings.Join(args, ", ") out += ")" @@ -835,15 +898,23 @@ func funcType(ft *ast.FuncType) (*Function, error) { } for ; x < len(ft.Params.List); x++ { param := ft.Params.List[x] - t := fmt.Sprint(param.Type) + typeNode := param.Type + isVariadic := false + if ellipsis, ok := param.Type.(*ast.Ellipsis); ok { + isVariadic = true + typeNode = ellipsis.Elt + } + t := fmt.Sprint(typeNode) typ, ok := argTypes[t] if !ok { return nil, fmt.Errorf("unsupported argument type: %s", t) } - // support for foo, bar string for _, name := range param.Names { f.Args = append(f.Args, Arg{Name: name.Name, Type: typ}) } + if isVariadic { + f.IsVariadic = true + } } return f, nil }