|
18 | 18 | import builtins |
19 | 19 |
|
20 | 20 | from datetime import date, datetime |
21 | | -from types import MappingProxyType |
| 21 | +from types import MappingProxyType, UnionType |
22 | 22 | from typing import ( |
23 | 23 | TYPE_CHECKING, |
24 | 24 | Annotated, |
@@ -961,25 +961,32 @@ def _encode_value(self, type_hint: Any, value: Any) -> Any: |
961 | 961 | return value |
962 | 962 | # For every other field type, parse the possible value types |
963 | 963 | # from the type hint. |
| 964 | + # First, remove Annotated to get the actual attribute type. |
964 | 965 | type_hint_origin = get_type_origin(type_hint) or type_hint |
965 | 966 | attr_type = ( |
966 | 967 | get_type_args(type_hint)[0] |
967 | 968 | if type_hint_origin is Annotated |
968 | | - else type_hint_origin |
| 969 | + else type_hint |
969 | 970 | ) |
| 971 | + # Next, get a list of allowed values types from the attribute type. |
| 972 | + # If there is only one, create a list containing only that type. |
970 | 973 | attr_type_origin = get_type_origin(attr_type) or attr_type |
971 | 974 | value_types = ( |
972 | 975 | get_type_args(attr_type) |
973 | | - if attr_type_origin is Union |
974 | | - else [attr_type_origin] |
| 976 | + if attr_type_origin is Union or attr_type_origin is UnionType |
| 977 | + else [attr_type] |
975 | 978 | ) |
976 | 979 | # Recursively handle the types that need to be serialised. |
977 | 980 | for value_type in value_types: |
| 981 | + value_type_origin = get_type_origin(value_type) or value_type |
978 | 982 | if value_type is date and isinstance(value, date): |
979 | 983 | return value.strftime(DEFAULT_SERVER_DATE_FORMAT) |
980 | 984 | if value_type is datetime and isinstance(value, datetime): |
981 | 985 | return value.strftime(DEFAULT_SERVER_DATETIME_FORMAT) |
982 | | - if value_type is list and isinstance(value, (list, set, tuple)): |
983 | | - v_type = get_type_args(type_hint)[0] |
| 986 | + if value_type_origin is list and isinstance( |
| 987 | + value, |
| 988 | + (list, set, tuple), |
| 989 | + ): |
| 990 | + v_type = get_type_args(value_type)[0] |
984 | 991 | return [self._encode_value(v_type, v) for v in value] |
985 | 992 | return value |
0 commit comments