Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bindparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ type BindStyledParameterOptions struct {
// When set to "byte" and the destination is []byte, the value is
// base64-decoded rather than treated as a generic slice.
Format string
// AllowReserved, when true, indicates that the parameter value may
// contain RFC 3986 reserved characters without percent-encoding.
AllowReserved bool
}

// BindStyledParameterWithOptions binds a parameter as described in the Path Parameters
Expand Down Expand Up @@ -346,6 +349,9 @@ type BindQueryParameterOptions struct {
// When set to "byte" and the destination is []byte, the value is
// base64-decoded rather than treated as a generic slice.
Format string
// AllowReserved, when true, indicates that the parameter value may
// contain RFC 3986 reserved characters without percent-encoding.
AllowReserved bool
}

// BindQueryParameterWithOptions works like BindQueryParameter with additional options.
Expand Down
121 changes: 89 additions & 32 deletions styleparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ type StyleParamOptions struct {
Format string
// Required indicates whether the parameter is required.
Required bool
// AllowReserved, when true, prevents percent-encoding of RFC 3986
// reserved characters in query parameter values. Per the OpenAPI 3.x
// spec, this only applies to query parameters.
AllowReserved bool
}

// StyleParamWithOptions serializes a Go value into an OpenAPI-styled parameter
Expand Down Expand Up @@ -105,32 +109,32 @@ func StyleParamWithOptions(style string, explode bool, paramName string, value i
return "", fmt.Errorf("error marshaling '%s' as text: %w", value, err)
}

return stylePrimitive(style, explode, paramName, opts.ParamLocation, string(b))
return stylePrimitive(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, string(b))
}
}

switch t.Kind() {
case reflect.Slice:
if opts.Format == "byte" && isByteSlice(t) {
encoded := base64.StdEncoding.EncodeToString(v.Bytes())
return stylePrimitive(style, explode, paramName, opts.ParamLocation, encoded)
return stylePrimitive(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, encoded)
}
n := v.Len()
sliceVal := make([]interface{}, n)
for i := 0; i < n; i++ {
sliceVal[i] = v.Index(i).Interface()
}
return styleSlice(style, explode, paramName, opts.ParamLocation, sliceVal)
return styleSlice(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, sliceVal)
case reflect.Struct:
return styleStruct(style, explode, paramName, opts.ParamLocation, value)
return styleStruct(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, value)
case reflect.Map:
return styleMap(style, explode, paramName, opts.ParamLocation, value)
return styleMap(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, value)
default:
return stylePrimitive(style, explode, paramName, opts.ParamLocation, value)
return stylePrimitive(style, explode, paramName, opts.ParamLocation, opts.AllowReserved, value)
}
}

func styleSlice(style string, explode bool, paramName string, paramLocation ParamLocation, values []interface{}) (string, error) {
func styleSlice(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, values []interface{}) (string, error) {
if style == "deepObject" {
if !explode {
return "", errors.New("deepObjects must be exploded")
Expand All @@ -141,6 +145,8 @@ func styleSlice(style string, explode bool, paramName string, paramLocation Para
var prefix string
var separator string

escapedName := escapeParameterName(paramName, paramLocation)

switch style {
case "simple":
separator = ","
Expand All @@ -152,28 +158,28 @@ func styleSlice(style string, explode bool, paramName string, paramLocation Para
separator = ","
}
case "matrix":
prefix = fmt.Sprintf(";%s=", paramName)
prefix = fmt.Sprintf(";%s=", escapedName)
if explode {
separator = prefix
} else {
separator = ","
}
case "form":
prefix = fmt.Sprintf("%s=", paramName)
prefix = fmt.Sprintf("%s=", escapedName)
if explode {
separator = "&" + prefix
} else {
separator = ","
}
case "spaceDelimited":
prefix = fmt.Sprintf("%s=", paramName)
prefix = fmt.Sprintf("%s=", escapedName)
if explode {
separator = "&" + prefix
} else {
separator = " "
}
case "pipeDelimited":
prefix = fmt.Sprintf("%s=", paramName)
prefix = fmt.Sprintf("%s=", escapedName)
if explode {
separator = "&" + prefix
} else {
Expand All @@ -189,7 +195,7 @@ func styleSlice(style string, explode bool, paramName string, paramLocation Para
parts := make([]string, len(values))
for i, v := range values {
part, err = primitiveToString(v)
part = escapeParameterString(part, paramLocation)
part = escapeParameterString(part, paramLocation, allowReserved)
parts[i] = part
if err != nil {
return "", fmt.Errorf("error formatting '%s': %w", paramName, err)
Expand Down Expand Up @@ -236,9 +242,9 @@ func marshalKnownTypes(value interface{}) (string, bool) {
return "", false
}

func styleStruct(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
func styleStruct(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, value interface{}) (string, error) {
if timeVal, ok := marshalKnownTypes(value); ok {
styledVal, err := stylePrimitive(style, explode, paramName, paramLocation, timeVal)
styledVal, err := stylePrimitive(style, explode, paramName, paramLocation, allowReserved, timeVal)
if err != nil {
return "", fmt.Errorf("failed to style time: %w", err)
}
Expand Down Expand Up @@ -266,7 +272,10 @@ func styleStruct(style string, explode bool, paramName string, paramLocation Par
if err != nil {
return "", fmt.Errorf("failed to unmarshal JSON: %w", err)
}
s, err := StyleParamWithLocation(style, explode, paramName, paramLocation, i2)
s, err := StyleParamWithOptions(style, explode, paramName, i2, StyleParamOptions{
ParamLocation: paramLocation,
AllowReserved: allowReserved,
})
if err != nil {
return "", fmt.Errorf("error style JSON structure: %w", err)
}
Expand Down Expand Up @@ -305,10 +314,10 @@ func styleStruct(style string, explode bool, paramName string, paramLocation Par
fieldDict[fieldName] = str
}

return processFieldDict(style, explode, paramName, paramLocation, fieldDict)
return processFieldDict(style, explode, paramName, paramLocation, allowReserved, fieldDict)
}

func styleMap(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
func styleMap(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, value interface{}) (string, error) {
if style == "deepObject" {
if !explode {
return "", errors.New("deepObjects must be exploded")
Expand All @@ -325,29 +334,31 @@ func styleMap(style string, explode bool, paramName string, paramLocation ParamL
}
fieldDict[fieldName.String()] = str
}
return processFieldDict(style, explode, paramName, paramLocation, fieldDict)
return processFieldDict(style, explode, paramName, paramLocation, allowReserved, fieldDict)
}

func processFieldDict(style string, explode bool, paramName string, paramLocation ParamLocation, fieldDict map[string]string) (string, error) {
func processFieldDict(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, fieldDict map[string]string) (string, error) {
var parts []string

// This works for everything except deepObject. We'll handle that one
// separately.
if style != "deepObject" {
if explode {
for _, k := range sortedKeys(fieldDict) {
v := escapeParameterString(fieldDict[k], paramLocation)
v := escapeParameterString(fieldDict[k], paramLocation, allowReserved)
parts = append(parts, k+"="+v)
}
} else {
for _, k := range sortedKeys(fieldDict) {
v := escapeParameterString(fieldDict[k], paramLocation)
v := escapeParameterString(fieldDict[k], paramLocation, allowReserved)
parts = append(parts, k)
parts = append(parts, v)
}
}
}

escapedName := escapeParameterName(paramName, paramLocation)

var prefix string
var separator string

Expand All @@ -367,13 +378,13 @@ func processFieldDict(style string, explode bool, paramName string, paramLocatio
prefix = ";"
} else {
separator = ","
prefix = fmt.Sprintf(";%s=", paramName)
prefix = fmt.Sprintf(";%s=", escapedName)
}
case "form":
if explode {
separator = "&"
} else {
prefix = fmt.Sprintf("%s=", paramName)
prefix = fmt.Sprintf("%s=", escapedName)
separator = ","
}
case "deepObject":
Expand All @@ -383,7 +394,7 @@ func processFieldDict(style string, explode bool, paramName string, paramLocatio
}
for _, k := range sortedKeys(fieldDict) {
v := fieldDict[k]
part := fmt.Sprintf("%s[%s]=%s", paramName, k, v)
part := fmt.Sprintf("%s[%s]=%s", escapedName, k, v)
parts = append(parts, part)
}
separator = "&"
Expand All @@ -395,25 +406,27 @@ func processFieldDict(style string, explode bool, paramName string, paramLocatio
return prefix + strings.Join(parts, separator), nil
}

func stylePrimitive(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
func stylePrimitive(style string, explode bool, paramName string, paramLocation ParamLocation, allowReserved bool, value interface{}) (string, error) {
strVal, err := primitiveToString(value)
if err != nil {
return "", err
}

escapedName := escapeParameterName(paramName, paramLocation)

var prefix string
switch style {
case "simple":
case "label":
prefix = "."
case "matrix":
prefix = fmt.Sprintf(";%s=", paramName)
prefix = fmt.Sprintf(";%s=", escapedName)
case "form":
prefix = fmt.Sprintf("%s=", paramName)
prefix = fmt.Sprintf("%s=", escapedName)
default:
return "", fmt.Errorf("unsupported style '%s'", style)
}
return prefix + escapeParameterString(strVal, paramLocation), nil
return prefix + escapeParameterString(strVal, paramLocation, allowReserved), nil
}

// Converts a primitive value to a string. We need to do this based on the
Expand Down Expand Up @@ -486,16 +499,60 @@ func primitiveToString(value interface{}) (string, error) {
return output, nil
}

// escapeParameterString escapes a parameter value bas on the location of that parameter.
// Query params and path params need different kinds of escaping, while header
// and cookie params seem not to need escaping.
func escapeParameterString(value string, paramLocation ParamLocation) string {
// escapeParameterName escapes a parameter name for use in query strings and
// paths. This ensures characters like [] in parameter names (e.g. user_ids[])
// are properly percent-encoded per RFC 3986.
func escapeParameterName(name string, paramLocation ParamLocation) string {
// Parameter names should always be encoded regardless of allowReserved,
// which only applies to values per the OpenAPI spec.
return escapeParameterString(name, paramLocation, false)
}

// escapeParameterString escapes a parameter value based on the location of
// that parameter. Query params and path params need different kinds of
// escaping, while header and cookie params seem not to need escaping.
// When allowReserved is true and the location is query, RFC 3986 reserved
// characters are left unencoded per the OpenAPI allowReserved specification.
func escapeParameterString(value string, paramLocation ParamLocation, allowReserved bool) string {
switch paramLocation {
case ParamLocationQuery:
if allowReserved {
return escapeQueryAllowReserved(value)
}
return url.QueryEscape(value)
case ParamLocationPath:
return url.PathEscape(value)
default:
return value
}
}

// escapeQueryAllowReserved percent-encodes a query parameter value while
// leaving RFC 3986 reserved characters (:/?#[]@!$&'()*+,;=) unencoded, as
// specified by OpenAPI's allowReserved parameter option. Only characters that
// are neither unreserved nor reserved are encoded (e.g., spaces, control
// characters, non-ASCII).
func escapeQueryAllowReserved(value string) string {
// RFC 3986 reserved characters that should NOT be encoded when
// allowReserved is true.
const reserved = `:/?#[]@!$&'()*+,;=`

var buf strings.Builder
for _, b := range []byte(value) {
if isUnreserved(b) || strings.IndexByte(reserved, b) >= 0 {
buf.WriteByte(b)
} else {
fmt.Fprintf(&buf, "%%%02X", b)
}
}
return buf.String()
}

// isUnreserved reports whether the byte is an RFC 3986 unreserved character:
// ALPHA / DIGIT / "-" / "." / "_" / "~"
func isUnreserved(c byte) bool {
return (c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') ||
c == '-' || c == '.' || c == '_' || c == '~'
}
Loading