diff --git a/bindparam.go b/bindparam.go index cd2fdd5..535e11f 100644 --- a/bindparam.go +++ b/bindparam.go @@ -69,6 +69,9 @@ type BindStyledParameterOptions struct { // When set to "byte" and the destination is []byte, the value is // base64-decoded rather than treated as a generic slice. Format string + // AllowReserved, when true, indicates that the parameter value may + // contain RFC 3986 reserved characters without percent-encoding. + AllowReserved bool } // BindStyledParameterWithOptions binds a parameter as described in the Path Parameters @@ -346,6 +349,9 @@ type BindQueryParameterOptions struct { // When set to "byte" and the destination is []byte, the value is // base64-decoded rather than treated as a generic slice. Format string + // AllowReserved, when true, indicates that the parameter value may + // contain RFC 3986 reserved characters without percent-encoding. + AllowReserved bool } // BindQueryParameterWithOptions works like BindQueryParameter with additional options. diff --git a/styleparam.go b/styleparam.go index 389aa2a..2800733 100644 --- a/styleparam.go +++ b/styleparam.go @@ -72,6 +72,10 @@ type StyleParamOptions struct { Format string // Required indicates whether the parameter is required. Required bool + // AllowReserved, when true, prevents percent-encoding of RFC 3986 + // reserved characters in query parameter values. Per the OpenAPI 3.x + // spec, this only applies to query parameters. + AllowReserved bool } // StyleParamWithOptions serializes a Go value into an OpenAPI-styled parameter @@ -105,7 +109,7 @@ func StyleParamWithOptions(style string, explode bool, paramName string, value i return "", fmt.Errorf("error marshaling '%s' as text: %w", value, err) } - return stylePrimitive(style, explode, paramName, opts.ParamLocation, string(b)) + return stylePrimitive(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, string(b)) } } @@ -113,24 +117,24 @@ func StyleParamWithOptions(style string, explode bool, paramName string, value i case reflect.Slice: if opts.Format == "byte" && isByteSlice(t) { encoded := base64.StdEncoding.EncodeToString(v.Bytes()) - return stylePrimitive(style, explode, paramName, opts.ParamLocation, encoded) + return stylePrimitive(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, encoded) } n := v.Len() sliceVal := make([]interface{}, n) for i := 0; i < n; i++ { sliceVal[i] = v.Index(i).Interface() } - return styleSlice(style, explode, paramName, opts.ParamLocation, sliceVal) + return styleSlice(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, sliceVal) case reflect.Struct: - return styleStruct(style, explode, paramName, opts.ParamLocation, value) + return styleStruct(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, value) case reflect.Map: - return styleMap(style, explode, paramName, opts.ParamLocation, value) + return styleMap(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, value) default: - return stylePrimitive(style, explode, paramName, opts.ParamLocation, value) + return stylePrimitive(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, value) } } -func styleSlice(style string, explode bool, paramName string, paramLocation ParamLocation, values []interface{}) (string, error) { +func styleSlice(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, values []interface{}) (string, error) { if style == "deepObject" { if !explode { return "", errors.New("deepObjects must be exploded") @@ -141,6 +145,8 @@ func styleSlice(style string, explode bool, paramName string, paramLocation Para var prefix string var separator string + escapedName := escapeParameterName(paramName, paramLocation) + switch style { case "simple": separator = "," @@ -152,28 +158,28 @@ func styleSlice(style string, explode bool, paramName string, paramLocation Para separator = "," } case "matrix": - prefix = fmt.Sprintf(";%s=", paramName) + prefix = fmt.Sprintf(";%s=", escapedName) if explode { separator = prefix } else { separator = "," } case "form": - prefix = fmt.Sprintf("%s=", paramName) + prefix = fmt.Sprintf("%s=", escapedName) if explode { separator = "&" + prefix } else { separator = "," } case "spaceDelimited": - prefix = fmt.Sprintf("%s=", paramName) + prefix = fmt.Sprintf("%s=", escapedName) if explode { separator = "&" + prefix } else { separator = " " } case "pipeDelimited": - prefix = fmt.Sprintf("%s=", paramName) + prefix = fmt.Sprintf("%s=", escapedName) if explode { separator = "&" + prefix } else { @@ -189,7 +195,7 @@ func styleSlice(style string, explode bool, paramName string, paramLocation Para parts := make([]string, len(values)) for i, v := range values { part, err = primitiveToString(v) - part = escapeParameterString(part, paramLocation) + part = escapeParameterString(part, paramLocation, allowReserved) parts[i] = part if err != nil { return "", fmt.Errorf("error formatting '%s': %w", paramName, err) @@ -236,9 +242,9 @@ func marshalKnownTypes(value interface{}) (string, bool) { return "", false } -func styleStruct(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) { +func styleStruct(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, value interface{}) (string, error) { if timeVal, ok := marshalKnownTypes(value); ok { - styledVal, err := stylePrimitive(style, explode, paramName, paramLocation, timeVal) + styledVal, err := stylePrimitive(style, explode, paramName, paramLocation, allowReserved, timeVal) if err != nil { return "", fmt.Errorf("failed to style time: %w", err) } @@ -266,7 +272,10 @@ func styleStruct(style string, explode bool, paramName string, paramLocation Par if err != nil { return "", fmt.Errorf("failed to unmarshal JSON: %w", err) } - s, err := StyleParamWithLocation(style, explode, paramName, paramLocation, i2) + s, err := StyleParamWithOptions(style, explode, paramName, i2, StyleParamOptions{ + ParamLocation: paramLocation, + AllowReserved: allowReserved, + }) if err != nil { return "", fmt.Errorf("error style JSON structure: %w", err) } @@ -305,10 +314,10 @@ func styleStruct(style string, explode bool, paramName string, paramLocation Par fieldDict[fieldName] = str } - return processFieldDict(style, explode, paramName, paramLocation, fieldDict) + return processFieldDict(style, explode, paramName, paramLocation, allowReserved, fieldDict) } -func styleMap(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) { +func styleMap(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, value interface{}) (string, error) { if style == "deepObject" { if !explode { return "", errors.New("deepObjects must be exploded") @@ -325,10 +334,10 @@ func styleMap(style string, explode bool, paramName string, paramLocation ParamL } fieldDict[fieldName.String()] = str } - return processFieldDict(style, explode, paramName, paramLocation, fieldDict) + return processFieldDict(style, explode, paramName, paramLocation, allowReserved, fieldDict) } -func processFieldDict(style string, explode bool, paramName string, paramLocation ParamLocation, fieldDict map[string]string) (string, error) { +func processFieldDict(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, fieldDict map[string]string) (string, error) { var parts []string // This works for everything except deepObject. We'll handle that one @@ -336,18 +345,20 @@ func processFieldDict(style string, explode bool, paramName string, paramLocatio if style != "deepObject" { if explode { for _, k := range sortedKeys(fieldDict) { - v := escapeParameterString(fieldDict[k], paramLocation) + v := escapeParameterString(fieldDict[k], paramLocation, allowReserved) parts = append(parts, k+"="+v) } } else { for _, k := range sortedKeys(fieldDict) { - v := escapeParameterString(fieldDict[k], paramLocation) + v := escapeParameterString(fieldDict[k], paramLocation, allowReserved) parts = append(parts, k) parts = append(parts, v) } } } + escapedName := escapeParameterName(paramName, paramLocation) + var prefix string var separator string @@ -367,13 +378,13 @@ func processFieldDict(style string, explode bool, paramName string, paramLocatio prefix = ";" } else { separator = "," - prefix = fmt.Sprintf(";%s=", paramName) + prefix = fmt.Sprintf(";%s=", escapedName) } case "form": if explode { separator = "&" } else { - prefix = fmt.Sprintf("%s=", paramName) + prefix = fmt.Sprintf("%s=", escapedName) separator = "," } case "deepObject": @@ -383,7 +394,7 @@ func processFieldDict(style string, explode bool, paramName string, paramLocatio } for _, k := range sortedKeys(fieldDict) { v := fieldDict[k] - part := fmt.Sprintf("%s[%s]=%s", paramName, k, v) + part := fmt.Sprintf("%s[%s]=%s", escapedName, k, v) parts = append(parts, part) } separator = "&" @@ -395,25 +406,27 @@ func processFieldDict(style string, explode bool, paramName string, paramLocatio return prefix + strings.Join(parts, separator), nil } -func stylePrimitive(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) { +func stylePrimitive(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, value interface{}) (string, error) { strVal, err := primitiveToString(value) if err != nil { return "", err } + escapedName := escapeParameterName(paramName, paramLocation) + var prefix string switch style { case "simple": case "label": prefix = "." case "matrix": - prefix = fmt.Sprintf(";%s=", paramName) + prefix = fmt.Sprintf(";%s=", escapedName) case "form": - prefix = fmt.Sprintf("%s=", paramName) + prefix = fmt.Sprintf("%s=", escapedName) default: return "", fmt.Errorf("unsupported style '%s'", style) } - return prefix + escapeParameterString(strVal, paramLocation), nil + return prefix + escapeParameterString(strVal, paramLocation, allowReserved), nil } // Converts a primitive value to a string. We need to do this based on the @@ -486,12 +499,26 @@ func primitiveToString(value interface{}) (string, error) { return output, nil } -// escapeParameterString escapes a parameter value bas on the location of that parameter. -// Query params and path params need different kinds of escaping, while header -// and cookie params seem not to need escaping. -func escapeParameterString(value string, paramLocation ParamLocation) string { +// escapeParameterName escapes a parameter name for use in query strings and +// paths. This ensures characters like [] in parameter names (e.g. user_ids[]) +// are properly percent-encoded per RFC 3986. +func escapeParameterName(name string, paramLocation ParamLocation) string { + // Parameter names should always be encoded regardless of allowReserved, + // which only applies to values per the OpenAPI spec. + return escapeParameterString(name, paramLocation, false) +} + +// escapeParameterString escapes a parameter value based on the location of +// that parameter. Query params and path params need different kinds of +// escaping, while header and cookie params seem not to need escaping. +// When allowReserved is true and the location is query, RFC 3986 reserved +// characters are left unencoded per the OpenAPI allowReserved specification. +func escapeParameterString(value string, paramLocation ParamLocation, allowReserved bool) string { switch paramLocation { case ParamLocationQuery: + if allowReserved { + return escapeQueryAllowReserved(value) + } return url.QueryEscape(value) case ParamLocationPath: return url.PathEscape(value) @@ -499,3 +526,33 @@ func escapeParameterString(value string, paramLocation ParamLocation) string { return value } } + +// escapeQueryAllowReserved percent-encodes a query parameter value while +// leaving RFC 3986 reserved characters (:/?#[]@!$&'()*+,;=) unencoded, as +// specified by OpenAPI's allowReserved parameter option. Only characters that +// are neither unreserved nor reserved are encoded (e.g., spaces, control +// characters, non-ASCII). +func escapeQueryAllowReserved(value string) string { + // RFC 3986 reserved characters that should NOT be encoded when + // allowReserved is true. + const reserved = `:/?#[]@!$&'()*+,;=` + + var buf strings.Builder + for _, b := range []byte(value) { + if isUnreserved(b) || strings.IndexByte(reserved, b) >= 0 { + buf.WriteByte(b) + } else { + fmt.Fprintf(&buf, "%%%02X", b) + } + } + return buf.String() +} + +// isUnreserved reports whether the byte is an RFC 3986 unreserved character: +// ALPHA / DIGIT / "-" / "." / "_" / "~" +func isUnreserved(c byte) bool { + return (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || + c == '-' || c == '.' || c == '_' || c == '~' +} diff --git a/styleparam_test.go b/styleparam_test.go index 4d0f3fe..3fe6a29 100644 --- a/styleparam_test.go +++ b/styleparam_test.go @@ -756,3 +756,140 @@ func TestIssue37(t *testing.T) { } } } + +func TestStyleParamAllowReserved(t *testing.T) { + opts := func(allowReserved bool) StyleParamOptions { + return StyleParamOptions{ + ParamLocation: ParamLocationQuery, + AllowReserved: allowReserved, + } + } + + t.Run("primitive with reserved chars", func(t *testing.T) { + // Semicolons and colons are RFC 3986 reserved characters. + value := "List(79988552,27056405)" + + result, err := StyleParamWithOptions("form", false, "ids", value, opts(false)) + assert.NoError(t, err) + assert.EqualValues(t, "ids=List%2879988552%2C27056405%29", result, "reserved chars should be encoded when allowReserved=false") + + result, err = StyleParamWithOptions("form", false, "ids", value, opts(true)) + assert.NoError(t, err) + assert.EqualValues(t, "ids=List(79988552,27056405)", result, "reserved chars should be preserved when allowReserved=true") + }) + + t.Run("primitive with colons and slashes", func(t *testing.T) { + value := "2020-01-01T22:00:00+02:00" + + result, err := StyleParamWithOptions("form", false, "ts", value, opts(false)) + assert.NoError(t, err) + assert.EqualValues(t, "ts=2020-01-01T22%3A00%3A00%2B02%3A00", result) + + result, err = StyleParamWithOptions("form", false, "ts", value, opts(true)) + assert.NoError(t, err) + assert.EqualValues(t, "ts=2020-01-01T22:00:00+02:00", result) + }) + + t.Run("array with reserved chars in values", func(t *testing.T) { + values := []string{"a;b", "c:d"} + + result, err := StyleParamWithOptions("form", false, "items", values, opts(false)) + assert.NoError(t, err) + assert.EqualValues(t, "items=a%3Bb,c%3Ad", result) + + result, err = StyleParamWithOptions("form", false, "items", values, opts(true)) + assert.NoError(t, err) + assert.EqualValues(t, "items=a;b,c:d", result) + }) + + t.Run("array exploded with reserved chars", func(t *testing.T) { + values := []string{"a;b", "c:d"} + + result, err := StyleParamWithOptions("form", true, "items", values, opts(false)) + assert.NoError(t, err) + assert.EqualValues(t, "items=a%3Bb&items=c%3Ad", result) + + result, err = StyleParamWithOptions("form", true, "items", values, opts(true)) + assert.NoError(t, err) + assert.EqualValues(t, "items=a;b&items=c:d", result) + }) + + t.Run("spaces still encoded with allowReserved", func(t *testing.T) { + value := "hello world" + + result, err := StyleParamWithOptions("form", false, "q", value, opts(true)) + assert.NoError(t, err) + assert.EqualValues(t, "q=hello%20world", result, "spaces should still be encoded even with allowReserved=true") + }) + + t.Run("allowReserved has no effect on non-query locations", func(t *testing.T) { + value := "a;b" + + // Path params should still encode regardless of allowReserved. + result, err := StyleParamWithOptions("simple", false, "id", value, StyleParamOptions{ + ParamLocation: ParamLocationPath, + AllowReserved: true, + }) + assert.NoError(t, err) + assert.EqualValues(t, "a%3Bb", result, "path params should always encode reserved chars") + }) + + t.Run("zero value preserves existing behavior", func(t *testing.T) { + value := "123;456" + + // Default (AllowReserved: false) should match existing behavior. + result, err := StyleParamWithOptions("form", false, "id", value, StyleParamOptions{ + ParamLocation: ParamLocationQuery, + }) + assert.NoError(t, err) + assert.EqualValues(t, "id=123%3B456", result) + }) +} + +func TestStyleParamNameEncoding(t *testing.T) { + opts := StyleParamOptions{ParamLocation: ParamLocationQuery} + + t.Run("brackets in param name are encoded for query", func(t *testing.T) { + result, err := StyleParamWithOptions("form", true, "user_ids[]", []string{"1", "100"}, opts) + assert.NoError(t, err) + assert.EqualValues(t, "user_ids%5B%5D=1&user_ids%5B%5D=100", result) + }) + + t.Run("brackets in param name non-exploded", func(t *testing.T) { + result, err := StyleParamWithOptions("form", false, "user_ids[]", []string{"1", "100"}, opts) + assert.NoError(t, err) + assert.EqualValues(t, "user_ids%5B%5D=1,100", result) + }) + + t.Run("brackets in primitive param name", func(t *testing.T) { + result, err := StyleParamWithOptions("form", false, "filter[name]", "foo", opts) + assert.NoError(t, err) + assert.EqualValues(t, "filter%5Bname%5D=foo", result) + }) + + t.Run("simple alphanumeric name unchanged", func(t *testing.T) { + result, err := StyleParamWithOptions("form", false, "color", "blue", opts) + assert.NoError(t, err) + assert.EqualValues(t, "color=blue", result) + }) + + t.Run("path param name not encoded", func(t *testing.T) { + // Path params use the name in matrix style prefix + result, err := StyleParamWithOptions("matrix", false, "id", "5", StyleParamOptions{ + ParamLocation: ParamLocationPath, + }) + assert.NoError(t, err) + assert.EqualValues(t, ";id=5", result) + }) + + t.Run("deepObject param name not yet encoded", func(t *testing.T) { + // NOTE: MarshalDeepObject handles its own serialization and does not + // currently encode param names. This documents the current behavior. + type Obj struct { + Name string `json:"name"` + } + result, err := StyleParamWithOptions("deepObject", true, "filter[]", Obj{Name: "foo"}, opts) + assert.NoError(t, err) + assert.EqualValues(t, "filter[][name]=foo", result) + }) +}