1from __future__ import unicode_literals
2
3from django.core import checks
4from django.db import IntegrityError, connections, router
5from django.db.models import CASCADE
6from django.db.models.fields.related import ForeignKey, ManyToManyField
7from django.utils.functional import cached_property
8
9from django.db.models.fields.related import ReverseManyToOneDescriptor, ManyToManyDescriptor
10
11
12from modelcluster.utils import sort_by_fields
13
14from modelcluster.queryset import FakeQuerySet
15
16
def create_deferring_foreign_related_manager(related, original_manager_cls):
    """
    Create a DeferringRelatedManager class that wraps an ordinary RelatedManager
    with 'deferring' behaviour: any updates to the object set (via e.g. add() or clear())
    are written to a holding area rather than committed to the database immediately.
    Writing to the database is deferred until the model is saved.

    :param related: the reverse relation object for the foreign key; it must provide
        ``get_accessor_name()``, ``field`` and ``related_model``.
    :param original_manager_cls: the live related-manager class for this relation.
        It is instantiated with the parent instance whenever the database state is
        needed, and its ``add()`` is used to write changes back in ``commit()``.
    :return: a manager class whose constructor takes the parent instance.

    The holding area is ``instance._cluster_related_objects[<accessor name>]``, a plain
    list of child objects. While that key is absent, the manager reflects the live
    database.
    """

    relation_name = related.get_accessor_name()
    rel_field = related.field
    rel_model = related.related_model
    # Subclass the child model's default manager class, so its methods are inherited
    # by the deferring manager.
    superclass = rel_model._default_manager.__class__

    class DeferringRelatedManager(superclass):
        def __init__(self, instance):
            super(DeferringRelatedManager, self).__init__()
            self.model = rel_model
            # the parent object whose children this manager exposes
            self.instance = instance

        def _get_cluster_related_objects(self):
            # Helper to retrieve the instance's _cluster_related_objects dict,
            # creating it if it does not already exist
            try:
                return self.instance._cluster_related_objects
            except AttributeError:
                cluster_related_objects = {}
                self.instance._cluster_related_objects = cluster_related_objects
                return cluster_related_objects

        def get_live_query_set(self):
            # deprecated; renamed to get_live_queryset to match the move from
            # get_query_set to get_queryset in Django 1.6
            return self.get_live_queryset()

        def get_live_queryset(self):
            """
            return the original manager's queryset, which reflects the live database
            """
            return original_manager_cls(self.instance).get_queryset()

        def get_queryset(self):
            """
            return the current object set with any updates applied,
            wrapped up in a FakeQuerySet if it doesn't match the database state
            """
            try:
                results = self.instance._cluster_related_objects[relation_name]
            except (AttributeError, KeyError):
                # no in-memory state recorded for this relation: read the database
                return self.get_live_queryset()

            return FakeQuerySet(rel_model, results)

        def _apply_rel_filters(self, queryset):
            # Implemented as empty for compatibility sake
            # But there is probably a better implementation of this function
            #
            # NOTE: _apply_rel_filters() must return a copy of the queryset
            # to work correctly with prefetch
            return queryset._next_is_sticky().all()

        def get_prefetch_queryset(self, instances, queryset=None):
            """
            Fetch the children of all ``instances`` in one query, for Django's
            prefetch_related machinery. Returns the 6-tuple expected by Django:
            (queryset, rel_obj_attr, instance_attr, single, cache_name, is_descriptor).
            """
            if queryset is None:
                db = self._db or router.db_for_read(self.model, instance=instances[0])
                queryset = super(DeferringRelatedManager, self).get_queryset().using(db)

            rel_obj_attr = rel_field.get_local_related_value
            instance_attr = rel_field.get_foreign_related_value
            instances_dict = dict((instance_attr(inst), inst) for inst in instances)

            query = {'%s__in' % rel_field.name: instances}
            qs = queryset.filter(**query)
            # Since we just bypassed this class' get_queryset(), we must manage
            # the reverse relation manually.
            for rel_obj in qs:
                instance = instances_dict[rel_obj_attr(rel_obj)]
                setattr(rel_obj, rel_field.name, instance)
            cache_name = rel_field.related_query_name()
            return qs, rel_obj_attr, instance_attr, False, cache_name, False

        def get_object_list(self):
            """
            return the mutable list that forms the current in-memory state of
            this relation. If there is no such list (i.e. the manager is returning
            querysets from the live database instead), one is created, populating it
            with the live database state
            """
            cluster_related_objects = self._get_cluster_related_objects()

            try:
                object_list = cluster_related_objects[relation_name]
            except KeyError:
                object_list = list(self.get_live_queryset())
                cluster_related_objects[relation_name] = object_list

            return object_list

        def add(self, *new_items):
            """
            Add the passed items to the stored object set, but do not commit them
            to the database
            """
            items = self.get_object_list()

            for target in new_items:
                item_matched = False
                for i, item in enumerate(items):
                    if item == target:
                        # Replace the matched item with the new one. This ensures that any
                        # modifications to that item's fields take effect within the recordset -
                        # i.e. we can perform a virtual UPDATE to an object in the list
                        # by calling add(updated_object). Which is semantically a bit dubious,
                        # but it does the job...
                        items[i] = target
                        item_matched = True
                        break
                if not item_matched:
                    items.append(target)

                # update the foreign key on the added item to point back to the parent instance
                setattr(target, rel_field.name, self.instance)

            # Sort list
            if rel_model._meta.ordering and len(items) > 1:
                sort_by_fields(items, rel_model._meta.ordering)

        def remove(self, *items_to_remove):
            """
            Remove the passed items from the stored object set, but do not commit the change
            to the database
            """
            items = self.get_object_list()

            # filter items list in place: see http://stackoverflow.com/a/1208792/1853523
            items[:] = [item for item in items if item not in items_to_remove]

        def create(self, **kwargs):
            """
            Instantiate a new child object from ``kwargs`` and add it to the stored
            object set. Unlike Django's RelatedManager.create(), nothing is saved
            here; the object is written to the database by commit().

            :return: the new, unsaved child object.
            """
            items = self.get_object_list()
            new_item = rel_model(**kwargs)
            # Point the foreign key back at the parent instance, as add() and set() do,
            # so the new child's parent accessor works before anything is committed.
            setattr(new_item, rel_field.name, self.instance)
            items.append(new_item)

            # Keep the list in the model's declared order, as add() and set() do.
            if rel_model._meta.ordering and len(items) > 1:
                sort_by_fields(items, rel_model._meta.ordering)

            return new_item

        def clear(self):
            """
            Clear the stored object set, without affecting the database
            """
            self.set([])

        def set(self, objs, bulk=True, clear=False):
            """
            Replace the stored object set with ``objs``, without affecting the database.
            ``bulk`` and ``clear`` are accepted for signature compatibility with Django's
            RelatedManager.set() and are ignored.
            """
            # cast objs to a list so that:
            # 1) we can call len() on it (which we can't do on, say, a queryset)
            # 2) if we need to sort it, we can do so without mutating the original
            objs = list(objs)

            cluster_related_objects = self._get_cluster_related_objects()

            for obj in objs:
                # update the foreign key on the added item to point back to the parent instance
                setattr(obj, rel_field.name, self.instance)

            # Clone and sort the 'objs' list, if necessary
            if rel_model._meta.ordering and len(objs) > 1:
                sort_by_fields(objs, rel_model._meta.ordering)

            cluster_related_objects[relation_name] = objs

        def commit(self):
            """
            Apply any changes made to the stored object set to the database.
            Any objects removed from the initial set will be deleted entirely
            from the database.

            :raises IntegrityError: if the parent instance has no primary key yet.
            """
            if self.instance.pk is None:
                raise IntegrityError("Cannot commit relation %r on an unsaved model" % relation_name)

            try:
                final_items = self.instance._cluster_related_objects[relation_name]
            except (AttributeError, KeyError):
                # _cluster_related_objects entry never created => no changes to make
                return

            original_manager = original_manager_cls(self.instance)

            live_items = list(original_manager.get_queryset())
            for item in live_items:
                if item not in final_items:
                    item.delete()

            for item in final_items:
                # Django 1.9+ bulk updates items by default which assumes
                # that they have already been saved to the database.
                # Disable this behaviour.
                # https://code.djangoproject.com/ticket/18556
                # https://github.com/django/django/commit/adc0c4fbac98f9cb975e8fa8220323b2de638b46
                original_manager.add(item, bulk=False)

            # purge the _cluster_related_objects entry, so we switch back to live SQL
            del self.instance._cluster_related_objects[relation_name]

    return DeferringRelatedManager
216
217
class ChildObjectsDescriptor(ReverseManyToOneDescriptor):
    """
    Reverse-relation descriptor for ParentalKey.

    Reading the attribute on a parent instance yields a deferring manager, built by
    create_deferring_foreign_related_manager, instead of Django's live related manager.
    Assigning to the attribute replaces that manager's in-memory object set.
    """

    def __get__(self, instance, instance_type=None):
        # Access through the class (no instance) hands back the descriptor itself.
        return self if instance is None else self.child_object_manager_cls(instance)

    def __set__(self, instance, value):
        self.__get__(instance).set(value)

    @cached_property
    def child_object_manager_cls(self):
        # The deferring manager class is generated once per descriptor and cached.
        return create_deferring_foreign_related_manager(self.rel, self.related_manager_cls)
232
233
class ParentalKey(ForeignKey):
    """
    A ForeignKey from a child model to a ClusterableModel parent.

    The reverse accessor installed on the parent is a ChildObjectsDescriptor, so
    changes to the child set made through it are held in memory rather than being
    written to the database immediately.
    """
    related_accessor_class = ChildObjectsDescriptor

    def __init__(self, *args, **kwargs):
        # default to CASCADE when the caller does not pass on_delete
        kwargs.setdefault('on_delete', CASCADE)
        super(ParentalKey, self).__init__(*args, **kwargs)

    def check(self, **kwargs):
        """
        Run Django's ForeignKey system checks, plus:

        - modelcluster.E001: the target model must be a subclass of ClusterableModel.
        - modelcluster.E002: the relation must not suppress its reverse accessor.
        """
        from modelcluster.models import ClusterableModel

        errors = super(ParentalKey, self).check(**kwargs)

        # Check that the destination model is a subclass of ClusterableModel.
        # If self.remote_field.model is still a string at this point, it means that
        # Django has been unable to resolve it as a model name; if so, skip this test
        # so that Django's own system checks can report the appropriate error
        target = self.remote_field.model
        if isinstance(target, type) and not issubclass(target, ClusterableModel):
            errors.append(
                checks.Error(
                    'ParentalKey must point to a subclass of ClusterableModel.',
                    hint='Change {model_name} into a ClusterableModel or use a ForeignKey instead.'.format(
                        model_name=target._meta.app_label + '.' + target.__name__,
                    ),
                    obj=self,
                    id='modelcluster.E001',
                )
            )

        # ParentalKeys must have an accessor name (#49). Django suppresses the reverse
        # accessor not only for related_name='+' but for any related_name ending in '+',
        # so reject both forms.
        accessor_name = self.remote_field.get_accessor_name()
        if accessor_name and accessor_name.endswith('+'):
            errors.append(
                checks.Error(
                    "related_name='+' is not allowed on ParentalKey fields",
                    hint="Either change it to a valid name or remove it",
                    obj=self,
                    id='modelcluster.E002',
                )
            )

        return errors
274
275
def create_deferring_forward_many_to_many_manager(rel, original_manager_cls):
    """
    Create a DeferringManyRelatedManager class for the forward side of a
    ParentalManyToManyField. Updates to the object set (add(), remove(), set(),
    clear()) are held in memory and only written to the database by commit().

    :param rel: the many-to-many relation object; it must provide ``field``,
        ``model`` (the target model) and ``through``.
    :param original_manager_cls: the live many-related-manager class for this
        relation. It is instantiated with the source instance whenever the database
        state is needed, and is used to apply changes in ``commit()``.
    :return: a manager class whose constructor takes the source instance.

    The in-memory state lives in ``instance._cluster_related_objects[<field name>]``.
    While that key is absent, the manager reflects the live database.
    """
    rel_field = rel.field
    relation_name = rel_field.name
    query_field_name = rel_field.related_query_name()
    source_field_name = rel_field.m2m_field_name()
    rel_model = rel.model
    superclass = rel_model._default_manager.__class__
    rel_through = rel.through

    class DeferringManyRelatedManager(superclass):
        def __init__(self, instance=None):
            super(DeferringManyRelatedManager, self).__init__()
            self.model = rel_model
            self.through = rel_through
            # the object on the source side of the many-to-many relation
            self.instance = instance

        def get_original_manager(self):
            return original_manager_cls(self.instance)

        def get_live_queryset(self):
            """
            return the original manager's queryset, which reflects the live database
            """
            return self.get_original_manager().get_queryset()

        def _get_cluster_related_objects(self):
            # Helper to retrieve the instance's _cluster_related_objects dict,
            # creating it if it does not already exist
            try:
                return self.instance._cluster_related_objects
            except AttributeError:
                cluster_related_objects = {}
                self.instance._cluster_related_objects = cluster_related_objects
                return cluster_related_objects

        def get_queryset(self):
            """
            return the current object set with any updates applied,
            wrapped up in a FakeQuerySet if it doesn't match the database state
            """
            try:
                results = self.instance._cluster_related_objects[relation_name]
            except (AttributeError, KeyError):
                if self.instance.pk:
                    return self.get_live_queryset()
                else:
                    # the standard M2M manager fails on unsaved instances,
                    # so bypass it and return an empty queryset. Use the default
                    # manager rather than `objects`, which a model is not
                    # guaranteed to have.
                    return rel_model._default_manager.none()

            return FakeQuerySet(rel_model, results)

        def get_prefetch_queryset(self, instances, queryset=None):
            """
            Fetch the related objects for all ``instances`` in one query, for Django's
            prefetch_related machinery; returns the 6-tuple Django expects.
            Derived from Django's ManyRelatedManager.get_prefetch_queryset.
            """
            if queryset is None:
                queryset = super(DeferringManyRelatedManager, self).get_queryset()

            queryset._add_hints(instance=instances[0])
            queryset = queryset.using(queryset._db or self._db)

            query = {'%s__in' % query_field_name: instances}
            queryset = queryset._next_is_sticky().filter(**query)

            fk = self.through._meta.get_field(source_field_name)
            join_table = fk.model._meta.db_table

            connection = connections[queryset.db]
            qn = connection.ops.quote_name

            # Select the join table's source-side key column(s) alongside each result,
            # so every result can be matched back to the instance it belongs to.
            queryset = queryset.extra(select={
                '_prefetch_related_val_%s' % f.attname:
                '%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})

            return (
                queryset,
                lambda result: tuple(
                    getattr(result, '_prefetch_related_val_%s' % f.attname)
                    for f in fk.local_related_fields
                ),
                lambda inst: tuple(
                    f.get_db_prep_value(getattr(inst, f.attname), connection)
                    for f in fk.foreign_related_fields
                ),
                False,
                relation_name,
                False,
            )

        def _apply_rel_filters(self, queryset):
            # Required for get_prefetch_queryset.
            return queryset._next_is_sticky()

        def get_object_list(self):
            """
            return the mutable list that forms the current in-memory state of
            this relation. If there is no such list (i.e. the manager is returning
            querysets from the live database instead), one is created, populating it
            with the live database state
            """
            cluster_related_objects = self._get_cluster_related_objects()

            try:
                object_list = cluster_related_objects[relation_name]
            except KeyError:
                object_list = list(self.get_live_queryset())
                cluster_related_objects[relation_name] = object_list

            return object_list

        def add(self, *new_items):
            """
            Add the passed items to the stored object set, but do not commit them
            to the database

            :raises ValueError: if any item has no primary key (it could not be
                linked through the join table on commit).
            """
            items = self.get_object_list()

            for target in new_items:
                if target.pk is None:
                    raise ValueError('"%r" needs to have a primary key value before '
                        'it can be added to a parental many-to-many relation.' % target)
                item_matched = False
                for i, item in enumerate(items):
                    if item == target:
                        # Replace the matched item with the new one. This ensures that any
                        # modifications to that item's fields take effect within the recordset -
                        # i.e. we can perform a virtual UPDATE to an object in the list
                        # by calling add(updated_object). Which is semantically a bit dubious,
                        # but it does the job...
                        items[i] = target
                        item_matched = True
                        break
                if not item_matched:
                    items.append(target)

            # Sort list
            if rel_model._meta.ordering and len(items) > 1:
                sort_by_fields(items, rel_model._meta.ordering)

        def clear(self):
            """
            Clear the stored object set, without affecting the database
            """
            self.set([])

        def set(self, objs, bulk=True, clear=False):
            """
            Replace the stored object set with ``objs``, without affecting the database.

            If ``objs`` holds something other than model instances (assumed to be primary
            keys), the call is passed straight to the original manager, which writes to
            the database immediately. ``bulk`` and ``clear`` are accepted for signature
            compatibility and are not used.
            """
            # cast objs to a list so that:
            # 1) we can call len() on it (which we can't do on, say, a queryset)
            # 2) if we need to sort it, we can do so without mutating the original
            objs = list(objs)

            if objs and not isinstance(objs[0], rel_model):
                # assume objs is a list of pks (like when loading data from a
                # fixture), and allow the original manager to handle things
                original_manager = self.get_original_manager()
                original_manager.set(objs)
                return

            cluster_related_objects = self._get_cluster_related_objects()

            # Clone and sort the 'objs' list, if necessary
            if rel_model._meta.ordering and len(objs) > 1:
                sort_by_fields(objs, rel_model._meta.ordering)

            cluster_related_objects[relation_name] = objs

        def remove(self, *items_to_remove):
            """
            Remove the passed items from the stored object set, but do not commit the change
            to the database
            """
            items = self.get_object_list()

            # filter items list in place: see http://stackoverflow.com/a/1208792/1853523
            items[:] = [item for item in items if item not in items_to_remove]

        def commit(self):
            """
            Apply any changes made to the stored object set to the database.
            Only the join-table links are changed; the related objects themselves
            are neither created nor deleted.

            :raises IntegrityError: if the source instance has no (truthy) primary key.
            """
            if not self.instance.pk:
                raise IntegrityError("Cannot commit relation %r on an unsaved model" % relation_name)

            try:
                final_items = self.instance._cluster_related_objects[relation_name]
            except (AttributeError, KeyError):
                # _cluster_related_objects entry never created => no changes to make
                return

            original_manager = self.get_original_manager()
            live_items = list(original_manager.get_queryset())

            items_to_remove = [item for item in live_items if item not in final_items]
            items_to_add = [item for item in final_items if item not in live_items]

            if items_to_remove:
                original_manager.remove(*items_to_remove)
            if items_to_add:
                original_manager.add(*items_to_add)

            # purge the _cluster_related_objects entry, so we switch back to live SQL
            del self.instance._cluster_related_objects[relation_name]

    return DeferringManyRelatedManager
479
480
class ParentalManyToManyDescriptor(ManyToManyDescriptor):
    """
    Forward descriptor for ParentalManyToManyField.

    Reading the attribute on an instance yields a deferring manager, built by
    create_deferring_forward_many_to_many_manager, instead of Django's live
    many-related manager. Assigning to the attribute calls that manager's set().
    """

    def __get__(self, instance, instance_type=None):
        # Access through the class (no instance) hands back the descriptor itself.
        return self if instance is None else self.child_object_manager_cls(instance)

    def __set__(self, instance, value):
        self.__get__(instance).set(value)

    @cached_property
    def child_object_manager_cls(self):
        # The deferring manager class is generated once per descriptor and cached.
        return create_deferring_forward_many_to_many_manager(self.rel, self.related_manager_cls)
497
498
class ParentalManyToManyField(ManyToManyField):
    """
    A ManyToManyField whose forward accessor is a ParentalManyToManyDescriptor, so
    changes made through it are held in memory until the relation is committed.
    """
    related_accessor_class = ParentalManyToManyDescriptor
    # Marker flag; presumably read by ClusterableModel / form code elsewhere to know
    # that assigned values need an explicit commit — TODO confirm against callers.
    _need_commit_after_assignment = True

    def contribute_to_class(self, cls, name, **kwargs):
        # ManyToManyField does not (as of Django 1.10) respect related_accessor_class,
        # but hard-codes ManyToManyDescriptor instead:
        # https://github.com/django/django/blob/6157cd6da1b27716e8f3d1ed692a6e33d970ae46/django/db/models/fields/related.py#L1538
        # So, let the stock contribute_to_class run first, then replace the
        # accessor it installed with our own descriptor.
        super(ParentalManyToManyField, self).contribute_to_class(cls, name, **kwargs)
        descriptor = self.related_accessor_class(self.remote_field)
        setattr(cls, self.name, descriptor)

    def value_from_object(self, obj):
        # In Django >=1.10, ManyToManyField.value_from_object special-cases objects
        # with no PK, returning an empty list on the basis that unsaved objects can't
        # have related objects. Skip that special case and always ask the manager.
        manager = getattr(obj, self.attname)
        return manager.all()
517