Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions drf_writable_nested/mixins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from typing import List, Tuple

from django.contrib.contenttypes.fields import GenericRelation
Expand All @@ -10,10 +11,63 @@
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty, set_value
from rest_framework.settings import api_settings
from rest_framework.validators import UniqueValidator


class BaseNestedModelSerializer(serializers.ModelSerializer):
class FastToInternalValueMixin:
def fast_to_internal_value(self, data):
"""
Dict of native values <- Dict of primitive datatypes.
Skips validation.
"""
if not isinstance(data, Mapping):
message = self.error_messages['invalid'].format(
datatype=type(data).__name__
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='invalid')

ret = OrderedDict()
fields = self._writable_fields

for field in fields:
primitive_value = field.get_value(data)
if primitive_value is empty:
continue

set_value(ret, field.source_attrs, primitive_value)

return ret


class NestedOnlySerializerMixin(FastToInternalValueMixin, serializers.ModelSerializer):
"""
Required for all serializers that are nested under BaseNestedModelSerializer.
"""

def save(self, **kwargs):
self._validated_data = self.fast_to_internal_value(self.initial_data)
self._save_kwargs = defaultdict(dict, kwargs)
validated_data = {**self.validated_data, **kwargs}

if self.instance is not None:
self.instance = self.update(self.instance, validated_data)
assert self.instance is not None, (
'`update()` did not return an object instance.'
)
else:
self.instance = self.create(validated_data)
assert self.instance is not None, (
'`create()` did not return an object instance.'
)

return self.instance


class BaseNestedModelSerializer(FastToInternalValueMixin, serializers.ModelSerializer):
def _extract_relations(self, validated_data):
reverse_relations = OrderedDict()
relations = OrderedDict()
Expand Down Expand Up @@ -134,6 +188,7 @@ def _prefetch_related_instances(self, field, related_data):

return instances


def update_or_create_reverse_relations(self, instance, reverse_relations):
# Update or create reverse relations:
# many-to-one, many-to-many, reversed one-to-one
Expand Down Expand Up @@ -183,7 +238,8 @@ def update_or_create_reverse_relations(self, instance, reverse_relations):
data=data,
)
try:
serializer.is_valid(raise_exception=True)
serializer._errors = {}
serializer._validated_data = self.fast_to_internal_value(self.initial_data)
related_instance = serializer.save(**save_kwargs)
data['pk'] = related_instance.pk
new_related_instances.append(related_instance)
Expand All @@ -208,18 +264,23 @@ def update_or_create_direct_relations(self, attrs, relations):
data = self.get_initial()[field_name]
model_class = field.Meta.model
pk = self._get_related_pk(data, model_class)
# pk needs to be specified if it's not one to one or creation of new object is not intended
if pk:
obj = model_class.objects.filter(
pk=pk,
).first()
elif hasattr(self.instance, field_source):
obj = getattr(self.instance, field_source)
serializer = self._get_serializer_for_field(
field,
instance=obj,
data=data,
)

try:
serializer.is_valid(raise_exception=True)

serializer._errors = {}
serializer._validated_data = self.fast_to_internal_value(self.initial_data)
attrs[field_source] = serializer.save(
**self._get_save_kwargs(field_name)
)
Expand Down