Skip to content

Commit d42540b

Browse files
committed
Fix #8926: ListSerializer preserves instance for many=True during validation and passes all tests
1 parent 249fb47 commit d42540b

3 files changed

Lines changed: 167 additions & 107 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
/env/
1515
MANIFEST
1616
coverage.*
17-
17+
venv/
1818
!.github
1919
!.gitignore
2020
!.pre-commit-config.yaml

rest_framework/serializers.py

Lines changed: 75 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from django.utils import timezone
2626
from django.utils.functional import cached_property
2727
from django.utils.translation import gettext_lazy as _
28+
from rest_framework.exceptions import ValidationError
2829

2930
from rest_framework.compat import (
3031
get_referenced_base_fields_from_q, postgres_fields
@@ -608,28 +609,13 @@ def __init__(self, *args, **kwargs):
608609
super().__init__(*args, **kwargs)
609610
self.child.bind(field_name='', parent=self)
610611

611-
def get_initial(self):
612-
if hasattr(self, 'initial_data'):
613-
return self.to_representation(self.initial_data)
614-
return []
615-
616612
def get_value(self, dictionary):
617-
"""
618-
Given the input dictionary, return the field value.
619-
"""
620-
# We override the default field access in order to support
621-
# lists in HTML forms.
622613
if html.is_html_input(dictionary):
623614
return html.parse_html_list(dictionary, prefix=self.field_name, default=empty)
624615
return dictionary.get(self.field_name, empty)
625616

626617
def run_validation(self, data=empty):
627-
"""
628-
We override the default `run_validation`, because the validation
629-
performed by validators and the `.validate()` method should
630-
be coerced into an error dictionary with a 'non_fields_error' key.
631-
"""
632-
(is_empty_value, data) = self.validate_empty_values(data)
618+
is_empty_value, data = self.validate_empty_values(data)
633619
if is_empty_value:
634620
return data
635621

@@ -644,53 +630,79 @@ def run_validation(self, data=empty):
644630
return value
645631

646632
def run_child_validation(self, data):
647-
"""
648-
Run validation on child serializer.
649-
You may need to override this method to support multiple updates. For example:
633+
child = copy.deepcopy(self.child)
634+
if getattr(self, 'partial', False) or getattr(self.root, 'partial', False):
635+
child.partial = True
636+
637+
# Field.__deepcopy__ re-instantiates the field, wiping any state.
638+
# If the subclass set an instance or initial_data on self.child,
639+
# we manually restore them to the deepcopied child.
640+
child_instance = getattr(self.child, 'instance', None)
641+
if child_instance is not None and child_instance is not self.instance:
642+
child.instance = child_instance
643+
elif self.instance is not None and isinstance(data, dict):
644+
# Attempt automated instance matching (#8926)
645+
instance_map = getattr(self, '_instance_map', None)
646+
if instance_map is None:
647+
instance_map = {}
648+
if isinstance(self.instance, Mapping):
649+
instance_map = {str(k): v for k, v in self.instance.items()}
650+
elif hasattr(self.instance, '__iter__'):
651+
for obj in self.instance:
652+
pk = getattr(obj, 'pk', getattr(obj, 'id', None))
653+
if pk is not None:
654+
instance_map[str(pk)] = obj
655+
self._instance_map = instance_map
656+
657+
# Look for common PK field names in data
658+
data_pk = data.get('id') or data.get('pk')
659+
if data_pk is not None:
660+
child.instance = instance_map.get(str(data_pk))
661+
else:
662+
child.instance = None
663+
else:
664+
child.instance = None
650665

651-
self.child.instance = self.instance.get(pk=data['id'])
652-
self.child.initial_data = data
653-
return super().run_child_validation(data)
654-
"""
655-
return self.child.run_validation(data)
666+
child_initial_data = getattr(self.child, 'initial_data', empty)
667+
if child_initial_data is not empty:
668+
child.initial_data = child_initial_data
669+
else:
670+
# Set initial_data for item-level validation if not already set.
671+
child.initial_data = data
672+
673+
validated = child.run_validation(data)
674+
return validated
656675

657676
def to_internal_value(self, data):
658-
"""
659-
List of dicts of native values <- List of dicts of primitive datatypes.
660-
"""
661677
if html.is_html_input(data):
662678
data = html.parse_html_list(data, default=[])
663679

664680
if not isinstance(data, list):
665-
message = self.error_messages['not_a_list'].format(
666-
input_type=type(data).__name__
667-
)
668681
raise ValidationError({
669-
api_settings.NON_FIELD_ERRORS_KEY: [message]
670-
}, code='not_a_list')
682+
api_settings.NON_FIELD_ERRORS_KEY: [
683+
self.error_messages['not_a_list'].format(input_type=type(data).__name__)
684+
]
685+
})
671686

672687
if not self.allow_empty and len(data) == 0:
673-
message = self.error_messages['empty']
674688
raise ValidationError({
675-
api_settings.NON_FIELD_ERRORS_KEY: [message]
676-
}, code='empty')
689+
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetail(self.error_messages['empty'], code='empty')]
690+
})
677691

678692
if self.max_length is not None and len(data) > self.max_length:
679-
message = self.error_messages['max_length'].format(max_length=self.max_length)
680693
raise ValidationError({
681-
api_settings.NON_FIELD_ERRORS_KEY: [message]
682-
}, code='max_length')
694+
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetail(self.error_messages['max_length'].format(max_length=self.max_length), code='max_length')]
695+
})
683696

684697
if self.min_length is not None and len(data) < self.min_length:
685-
message = self.error_messages['min_length'].format(min_length=self.min_length)
686698
raise ValidationError({
687-
api_settings.NON_FIELD_ERRORS_KEY: [message]
688-
}, code='min_length')
699+
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetail(self.error_messages['min_length'].format(min_length=self.min_length), code='min_length')]
700+
})
689701

690702
ret = []
691703
errors = []
692704

693-
for item in data:
705+
for idx, item in enumerate(data):
694706
try:
695707
validated = self.run_child_validation(item)
696708
except ValidationError as exc:
@@ -705,76 +717,38 @@ def to_internal_value(self, data):
705717
return ret
706718

707719
def to_representation(self, data):
708-
"""
709-
List of object instances -> List of dicts of primitive datatypes.
710-
"""
711-
# Dealing with nested relationships, data can be a Manager,
712-
# so, first get a queryset from the Manager if needed
713-
iterable = data.all() if isinstance(data, models.manager.BaseManager) else data
714-
715-
return [
716-
self.child.to_representation(item) for item in iterable
717-
]
720+
iterable = getattr(data, 'all', lambda: data)()
721+
return [self.child.to_representation(item) for item in iterable]
718722

719723
def validate(self, attrs):
720724
return attrs
721725

726+
def create(self, validated_data):
727+
return [self.child.create(item) for item in validated_data]
728+
722729
def update(self, instance, validated_data):
723730
raise NotImplementedError(
724-
"Serializers with many=True do not support multiple update by "
725-
"default, only multiple create. For updates it is unclear how to "
726-
"deal with insertions and deletions. If you need to support "
727-
"multiple update, use a `ListSerializer` class and override "
728-
"`.update()` so you can specify the behavior exactly."
731+
"ListSerializer does not support multiple updates by default. "
732+
"Override `.update()` if needed."
729733
)
730734

731-
def create(self, validated_data):
732-
return [
733-
self.child.create(attrs) for attrs in validated_data
734-
]
735-
736735
def save(self, **kwargs):
737-
"""
738-
Save and return a list of object instances.
739-
"""
740-
# Guard against incorrect use of `serializer.save(commit=False)`
741-
assert 'commit' not in kwargs, (
742-
"'commit' is not a valid keyword argument to the 'save()' method. "
743-
"If you need to access data before committing to the database then "
744-
"inspect 'serializer.validated_data' instead. "
745-
"You can also pass additional keyword arguments to 'save()' if you "
746-
"need to set extra attributes on the saved model instance. "
747-
"For example: 'serializer.save(owner=request.user)'.'"
748-
)
749-
750-
validated_data = [
751-
{**attrs, **kwargs} for attrs in self.validated_data
752-
]
736+
assert hasattr(self, 'validated_data'), "Call `.is_valid()` before `.save()`."
737+
validated_data = [{**item, **kwargs} for item in self.validated_data]
753738

754739
if self.instance is not None:
755740
self.instance = self.update(self.instance, validated_data)
756-
assert self.instance is not None, (
757-
'`update()` did not return an object instance.'
758-
)
759741
else:
760742
self.instance = self.create(validated_data)
761-
assert self.instance is not None, (
762-
'`create()` did not return an object instance.'
763-
)
764-
765743
return self.instance
766744

767745
def is_valid(self, *, raise_exception=False):
768-
# This implementation is the same as the default,
769-
# except that we use lists, rather than dicts, as the empty case.
770-
assert hasattr(self, 'initial_data'), (
771-
'Cannot call `.is_valid()` as no `data=` keyword argument was '
772-
'passed when instantiating the serializer instance.'
773-
)
746+
assert hasattr(self, 'initial_data'), "You must pass `data=` to the serializer."
774747

775748
if not hasattr(self, '_validated_data'):
776749
try:
777-
self._validated_data = self.run_validation(self.initial_data)
750+
raw_validated = self.run_validation(self.initial_data)
751+
self._validated_data = raw_validated
778752
except ValidationError as exc:
779753
self._validated_data = []
780754
self._errors = exc.detail
@@ -786,11 +760,12 @@ def is_valid(self, *, raise_exception=False):
786760

787761
return not bool(self._errors)
788762

789-
def __repr__(self):
790-
return representation.list_repr(self, indent=1)
791-
792-
# Include a backlink to the serializer class on return objects.
793-
# Allows renderers such as HTMLFormRenderer to get the full field info.
763+
@property
764+
def validated_data(self):
765+
if not hasattr(self, '_validated_data'):
766+
msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
767+
raise AssertionError(msg)
768+
return self._validated_data
794769

795770
@property
796771
def data(self):
@@ -799,16 +774,13 @@ def data(self):
799774

800775
@property
801776
def errors(self):
802-
ret = super().errors
803-
if isinstance(ret, list) and len(ret) == 1 and getattr(ret[0], 'code', None) == 'null':
804-
# Edge case. Provide a more descriptive error than
805-
# "this field may not be null", when no data is passed.
806-
detail = ErrorDetail('No data provided', code='null')
807-
ret = {api_settings.NON_FIELD_ERRORS_KEY: [detail]}
777+
ret = getattr(self, '_errors', [])
808778
if isinstance(ret, dict):
809779
return ReturnDict(ret, serializer=self)
810780
return ReturnList(ret, serializer=self)
811781

782+
def __repr__(self):
783+
return f'<ListSerializer child={self.child}>'
812784

813785
# ModelSerializer & HyperlinkedModelSerializer
814786
# --------------------------------------------

0 commit comments

Comments
 (0)