Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions mage/args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,115 @@ Aliases: speak
}
}

func TestVariadicNoArgs(t *testing.T) {
stderr := &bytes.Buffer{}
stdout := &bytes.Buffer{}
inv := Invocation{
Dir: "./testdata/args",
Stderr: stderr,
Stdout: stdout,
Args: []string{"greet"},
}
code := Invoke(inv)
if code != 0 {
t.Log("stderr:", stderr)
t.Log("stdout:", stdout)
t.Fatalf("expected code 0, but got %v", code)
}
actual := stdout.String()
if actual != "" {
t.Fatalf("expected empty output, got %q", actual)
}
}

func TestVariadicOneArg(t *testing.T) {
stderr := &bytes.Buffer{}
stdout := &bytes.Buffer{}
inv := Invocation{
Dir: "./testdata/args",
Stderr: stderr,
Stdout: stdout,
Args: []string{"greet", "Alice"},
}
code := Invoke(inv)
if code != 0 {
t.Log("stderr:", stderr)
t.Log("stdout:", stdout)
t.Fatalf("expected code 0, but got %v", code)
}
actual := stdout.String()
expected := "Hello, Alice\n"
if actual != expected {
t.Fatalf("output is not expected:\n%q", actual)
}
}

func TestVariadicMultipleArgs(t *testing.T) {
stderr := &bytes.Buffer{}
stdout := &bytes.Buffer{}
inv := Invocation{
Dir: "./testdata/args",
Stderr: stderr,
Stdout: stdout,
Args: []string{"greet", "Alice", "Bob"},
}
code := Invoke(inv)
if code != 0 {
t.Log("stderr:", stderr)
t.Log("stdout:", stdout)
t.Fatalf("expected code 0, but got %v", code)
}
actual := stdout.String()
expected := "Hello, Alice\nHello, Bob\n"
if actual != expected {
t.Fatalf("output is not expected:\n%q", actual)
}
}

func TestVariadicMissingRequired(t *testing.T) {
stderr := &bytes.Buffer{}
stdout := &bytes.Buffer{}
inv := Invocation{
Dir: "./testdata/args",
Stderr: stderr,
Stdout: stdout,
Args: []string{"tag"},
}
code := Invoke(inv)
if code != 2 {
t.Log("stderr:", stderr)
t.Log("stdout:", stdout)
t.Fatalf("expected code 2, but got %v", code)
}
actual := stderr.String()
expected := "not enough arguments for target \"Tag\", expected at least 1, got 0\n"
if actual != expected {
t.Fatalf("output is not expected:\n%q", actual)
}
}

func TestVariadicWithRequired(t *testing.T) {
stderr := &bytes.Buffer{}
stdout := &bytes.Buffer{}
inv := Invocation{
Dir: "./testdata/args",
Stderr: stderr,
Stdout: stdout,
Args: []string{"tag", "myimage", "v1", "latest"},
}
code := Invoke(inv)
if code != 0 {
t.Log("stderr:", stderr)
t.Log("stdout:", stdout)
t.Fatalf("expected code 0, but got %v", code)
}
actual := stdout.String()
expected := "myimage\nv1\nlatest\n"
if actual != expected {
t.Fatalf("output is not expected:\n%q", actual)
}
}

func TestMgF(t *testing.T) {
stderr := &bytes.Buffer{}
stdout := &bytes.Buffer{}
Expand Down
50 changes: 33 additions & 17 deletions mage/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,11 +450,44 @@ Options:
switch _strings.ToLower(target) {
{{range .Funcs }}
case "{{lower .TargetName}}":
{{- if .IsVariadic}}
expected := x + {{len .Args}} - 1
{{- else}}
expected := x + {{len .Args}}
{{- end}}
if expected > len(args.Args) {
// note that expected and args at this point include the arg for the target itself
// so we subtract 1 here to show the number of args without the target.
{{- if .IsVariadic}}
logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected at least %v, got %v\n", expected-1, len(args.Args)-1)
{{- else}}
logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected %v, got %v\n", expected-1, len(args.Args)-1)
{{- end}}
os.Exit(2)
}
if args.Verbose {
logger.Println("Running target:", "{{.TargetName}}")
}
{{.ExecCode}}
handleError(logger, ret)
{{- end}}
{{range .Imports}}
{{$imp := .}}
{{range .Info.Funcs }}
case "{{lower .TargetName}}":
{{- if .IsVariadic}}
expected := x + {{len .Args}} - 1
{{- else}}
expected := x + {{len .Args}}
{{- end}}
if expected > len(args.Args) {
// note that expected and args at this point include the arg for the target itself
// so we subtract 1 here to show the number of args without the target.
{{- if .IsVariadic}}
logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected at least %v, got %v\n", expected-1, len(args.Args)-1)
{{- else}}
logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected %v, got %v\n", expected-1, len(args.Args)-1)
{{- end}}
os.Exit(2)
}
if args.Verbose {
Expand All @@ -463,23 +496,6 @@ Options:
{{.ExecCode}}
handleError(logger, ret)
{{- end}}
{{range .Imports}}
{{$imp := .}}
{{range .Info.Funcs }}
case "{{lower .TargetName}}":
expected := x + {{len .Args}}
if expected > len(args.Args) {
// note that expected and args at this point include the arg for the target itself
// so we subtract 1 here to show the number of args without the target.
logger.Printf("not enough arguments for target \"{{.TargetName}}\", expected %v, got %v\n", expected-1, len(args.Args)-1)
os.Exit(2)
}
if args.Verbose {
logger.Println("Running target:", "{{.TargetName}}")
}
{{.ExecCode}}
handleError(logger, ret)
{{- end}}
{{- end}}
default:
logger.Printf("Unknown target specified: %q\n", target)
Expand Down
16 changes: 16 additions & 0 deletions mage/testdata/args/magefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,19 @@ func HasDep() {
func DoubleIt(f float64) {
fmt.Printf("%.1f * 2 = %.1f\n", f, f*2)
}

// Greet says hello to each provided name.
func Greet(names ...string) {
for _, n := range names {
fmt.Println("Hello,", n)
}
}

// Tag prints an image name with optional extra labels.
func Tag(image string, labels ...string) error {
fmt.Println(image)
for _, l := range labels {
fmt.Println(l)
}
return nil
}
127 changes: 99 additions & 28 deletions parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type Function struct {
Synopsis string
Comment string
Args []Arg
IsVariadic bool
}

var _ sort.Interface = (Functions)(nil)
Expand Down Expand Up @@ -116,43 +117,101 @@ func (f Function) ExecCode() string {

var parseargs string
for x, arg := range f.Args {
switch arg.Type {
case "string":
parseargs += fmt.Sprintf(`
arg%d := args.Args[x]
x++`, x)
case "int":
parseargs += fmt.Sprintf(`
arg%d, err := strconv.Atoi(args.Args[x])
isLastVariadic := f.IsVariadic && x == len(f.Args)-1
if isLastVariadic {
switch arg.Type {
case "string":
parseargs += fmt.Sprintf(`
arg%d := args.Args[x:]
x = len(args.Args)`, x)
case "int":
parseargs += fmt.Sprintf(`
var arg%d []int
for _, _s := range args.Args[x:] {
_v, err := strconv.Atoi(_s)
if err != nil {
logger.Printf("can't convert argument %%q to int\n", args.Args[x])
logger.Printf("can't convert argument %%q to int\n", _s)
os.Exit(2)
}
x++`, x)
case "float64":
parseargs += fmt.Sprintf(`
arg%d, err := strconv.ParseFloat(args.Args[x], 64)
arg%d = append(arg%d, _v)
}
x = len(args.Args)`, x, x, x)
case "float64":
parseargs += fmt.Sprintf(`
var arg%d []float64
for _, _s := range args.Args[x:] {
_v, err := strconv.ParseFloat(_s, 64)
if err != nil {
logger.Printf("can't convert argument %%q to float64\n", args.Args[x])
logger.Printf("can't convert argument %%q to float64\n", _s)
os.Exit(2)
}
x++`, x)
case "bool":
parseargs += fmt.Sprintf(`
arg%d, err := strconv.ParseBool(args.Args[x])
arg%d = append(arg%d, _v)
}
x = len(args.Args)`, x, x, x)
case "bool":
parseargs += fmt.Sprintf(`
var arg%d []bool
for _, _s := range args.Args[x:] {
_v, err := strconv.ParseBool(_s)
if err != nil {
logger.Printf("can't convert argument %%q to bool\n", args.Args[x])
logger.Printf("can't convert argument %%q to bool\n", _s)
os.Exit(2)
}
x++`, x)
case "time.Duration":
parseargs += fmt.Sprintf(`
arg%d, err := time.ParseDuration(args.Args[x])
arg%d = append(arg%d, _v)
}
x = len(args.Args)`, x, x, x)
case "time.Duration":
parseargs += fmt.Sprintf(`
var arg%d []time.Duration
for _, _s := range args.Args[x:] {
_v, err := time.ParseDuration(_s)
if err != nil {
logger.Printf("can't convert argument %%q to time.Duration\n", args.Args[x])
logger.Printf("can't convert argument %%q to time.Duration\n", _s)
os.Exit(2)
}
x++`, x)
arg%d = append(arg%d, _v)
}
x = len(args.Args)`, x, x, x)
}
} else {
switch arg.Type {
case "string":
parseargs += fmt.Sprintf(`
arg%d := args.Args[x]
x++`, x)
case "int":
parseargs += fmt.Sprintf(`
arg%d, err := strconv.Atoi(args.Args[x])
if err != nil {
logger.Printf("can't convert argument %%q to int\n", args.Args[x])
os.Exit(2)
}
x++`, x)
case "float64":
parseargs += fmt.Sprintf(`
arg%d, err := strconv.ParseFloat(args.Args[x], 64)
if err != nil {
logger.Printf("can't convert argument %%q to float64\n", args.Args[x])
os.Exit(2)
}
x++`, x)
case "bool":
parseargs += fmt.Sprintf(`
arg%d, err := strconv.ParseBool(args.Args[x])
if err != nil {
logger.Printf("can't convert argument %%q to bool\n", args.Args[x])
os.Exit(2)
}
x++`, x)
case "time.Duration":
parseargs += fmt.Sprintf(`
arg%d, err := time.ParseDuration(args.Args[x])
if err != nil {
logger.Printf("can't convert argument %%q to time.Duration\n", args.Args[x])
os.Exit(2)
}
x++`, x)
}
}
}

Expand All @@ -168,7 +227,11 @@ func (f Function) ExecCode() string {
args = append(args, "ctx")
}
for x := 0; x < len(f.Args); x++ {
args = append(args, fmt.Sprintf("arg%d", x))
if f.IsVariadic && x == len(f.Args)-1 {
args = append(args, fmt.Sprintf("arg%d...", x))
} else {
args = append(args, fmt.Sprintf("arg%d", x))
}
}
out += strings.Join(args, ", ")
out += ")"
Expand Down Expand Up @@ -835,15 +898,23 @@ func funcType(ft *ast.FuncType) (*Function, error) {
}
for ; x < len(ft.Params.List); x++ {
param := ft.Params.List[x]
t := fmt.Sprint(param.Type)
typeNode := param.Type
isVariadic := false
if ellipsis, ok := param.Type.(*ast.Ellipsis); ok {
isVariadic = true
typeNode = ellipsis.Elt
}
t := fmt.Sprint(typeNode)
typ, ok := argTypes[t]
if !ok {
return nil, fmt.Errorf("unsupported argument type: %s", t)
}
// support for foo, bar string
for _, name := range param.Names {
f.Args = append(f.Args, Arg{Name: name.Name, Type: typ})
}
if isVariadic {
f.IsVariadic = true
}
}
return f, nil
}
Expand Down