1import operator
2import threading
3from functools import reduce, wraps
4
5from django.db import models
6from django.db.models.base import ModelBase
7from django.db.models.query import Q
8from django.db.models.query_utils import DeferredAttribute
9from django.utils.translation import gettext as _
10
11from mptt.compat import cached_field_value
12from mptt.fields import TreeForeignKey, TreeManyToManyField, TreeOneToOneField
13from mptt.managers import TreeManager
14from mptt.signals import node_moved
15from mptt.utils import _get_tree_model
16
17
# Public names of this module for star-imports. TreeForeignKey,
# TreeOneToOneField, TreeManyToManyField and TreeManager are imported from
# mptt.fields / mptt.managers above and re-exported here.
__all__ = (
    "TreeForeignKey",
    "TreeOneToOneField",
    "TreeManyToManyField",
    "TreeManager",
    "MPTTOptions",
    "MPTTModelBase",
    "MPTTModel",
)
27
28
29class _classproperty:
30    def __init__(self, getter, setter=None):
31        self.fget = getter
32        self.fset = setter
33
34    def __get__(self, cls, owner):
35        return self.fget(owner)
36
37    def __set__(self, cls, owner, value):
38        if not self.fset:
39            raise AttributeError("This classproperty is read only")
40        self.fset(owner, value)
41
42
43class classpropertytype(property):
44    def __init__(self, name, bases=(), members={}):
45        return super().__init__(
46            members.get("__get__"),
47            members.get("__set__"),
48            members.get("__delete__"),
49            members.get("__doc__"),
50        )
51
52
53classproperty = classpropertytype("classproperty")
54
55
class MPTTOptions:
    """
    Options class for MPTT models. Use this as an inner class called ``MPTTMeta``::

        class MyModel(MPTTModel):
            class MPTTMeta:
                order_insertion_by = ['name']
                parent_attr = 'myparent'

    Besides holding the option values, instances provide helpers to read and
    write tree fields on model instances, to snapshot the fields ``save()``
    compares against, and to find the insertion point that keeps
    ``order_insertion_by`` ordering.
    """

    # Class-level defaults, overridden per instance by __init__. The default
    # ``order_insertion_by`` list object is shared by every instance that
    # does not override it, so it must never be mutated in place
    # (get_ordered_insertion_target copies it before appending).
    order_insertion_by = []
    left_attr = "lft"
    right_attr = "rght"
    tree_id_attr = "tree_id"
    level_attr = "level"
    parent_attr = "parent"

    def __init__(self, opts=None, **kwargs):
        """
        Collect options from ``opts`` and from ``kwargs``.

        ``opts`` is typically an ``MPTTMeta`` class; every entry of its
        ``__dict__`` is used. Keyword arguments are applied afterwards and so
        take precedence. Names starting with ``__`` are skipped.
        ``order_insertion_by`` is normalised to a list.

        :raises ValueError: if the removed ``tree_manager_attr`` option is given.
        """
        # Override defaults with options provided
        if opts:
            opts = list(opts.__dict__.items())
        else:
            opts = []
        opts.extend(list(kwargs.items()))

        if "tree_manager_attr" in [opt[0] for opt in opts]:
            raise ValueError(
                "`tree_manager_attr` has been removed; you should instantiate"
                " a TreeManager as a normal manager on your model instead."
            )

        for key, value in opts:
            # Skip dunder entries (__module__, __doc__, __dict__, ...) that a
            # class __dict__ always contains.
            if key[:2] == "__":
                continue
            setattr(self, key, value)

        # Normalize order_insertion_by to a list
        if isinstance(self.order_insertion_by, str):
            self.order_insertion_by = [self.order_insertion_by]
        elif isinstance(self.order_insertion_by, tuple):
            self.order_insertion_by = list(self.order_insertion_by)
        elif self.order_insertion_by is None:
            self.order_insertion_by = []

    def __iter__(self):
        """Yield ``(name, value)`` for options set on this instance (non-underscore names)."""
        return ((k, v) for k, v in self.__dict__.items() if k[0] != "_")

    # Helper methods for accessing tree attributes on models.
    def get_raw_field_value(self, instance, field_name):
        """
        Gets the value of the given fieldname for the instance.
        This is not the same as getattr().
        This function will return IDs for foreignkeys etc, rather than doing
        a database query.
        """
        field = instance._meta.get_field(field_name)
        return field.value_from_object(instance)

    def set_raw_field_value(self, instance, field_name, value):
        """
        Sets the value of the given fieldname for the instance.
        This is not the same as setattr().
        This function requires an ID for a foreignkey (etc) rather than an instance.
        """
        field = instance._meta.get_field(field_name)
        setattr(instance, field.attname, value)

    def update_mptt_cached_fields(self, instance):
        """
        Caches (in an instance._mptt_cached_fields dict) the original values of:
         - parent pk
         - fields specified in order_insertion_by

        These are used in save() to determine if the relevant fields have changed,
        so that the MPTT fields need to be updated.
        """
        instance._mptt_cached_fields = {}
        field_names = {self.parent_attr}
        if self.order_insertion_by:
            for f in self.order_insertion_by:
                # Strip a leading "-" (descending order marker).
                if f[0] == "-":
                    f = f[1:]
                field_names.add(f)
        deferred_fields = instance.get_deferred_fields()
        for field_name in field_names:
            if deferred_fields:
                field = instance._meta.get_field(field_name)
                if (
                    field.attname in deferred_fields
                    and field.attname not in instance.__dict__
                ):
                    # deferred attribute (i.e. via .only() or .defer())
                    # It'd be silly to cache this (that'd do a database query)
                    # Instead, we mark it as a deferred attribute here, then
                    # assume it hasn't changed during save(), unless it's no
                    # longer deferred.
                    instance._mptt_cached_fields[field_name] = DeferredAttribute
                    continue
            instance._mptt_cached_fields[field_name] = self.get_raw_field_value(
                instance, field_name
            )

    def insertion_target_filters(self, instance, order_insertion_by):
        """
        Creates a filter which matches suitable right siblings for ``node``,
        where insertion should maintain ordering according to the list of
        fields in ``order_insertion_by``.

        For example, given an ``order_insertion_by`` of
        ``['field1', 'field2', 'field3']``, the resulting filter should
        correspond to the following SQL::

           field1 > %s
           OR (field1 = %s AND field2 > %s)
           OR (field1 = %s AND field2 = %s AND field3 > %s)

        A field name prefixed with ``-`` uses ``<`` instead of ``>``. When the
        node's value is NULL, ``__isnull`` is used for that field's
        comparison, and the equality prefix of later terms matches NULL.
        ``order_insertion_by`` must be non-empty.
        """
        fields = []
        filters = []
        fields__append = fields.append
        filters__append = filters.append
        and_ = operator.and_
        or_ = operator.or_
        for field_name in order_insertion_by:
            if field_name[0] == "-":
                field_name = field_name[1:]
                filter_suffix = "__lt"
            else:
                filter_suffix = "__gt"
            value = getattr(instance, field_name)
            if value is None:
                # node isn't saved yet. get the insertion value from pre_save.
                field = instance._meta.get_field(field_name)
                value = field.pre_save(instance, True)

            if value is None:
                # We have to use __isnull instead of __lt or __gt because
                # comparing against NULL is invalid. Depending on order, we
                # need to find the first node where the field is null or not null.
                q = Q(**{field_name + "__isnull": filter_suffix == "__lt"})
            else:
                q = Q(**{field_name + filter_suffix: value})

            filters__append(reduce(and_, [Q(**{f: v}) for f, v in fields] + [q]))
            # Keep the node's actual value for the equality prefix of the
            # following fields: Q(field=None) is translated to IS NULL.
            # Previously the __isnull flag (True/False) was stored here, so
            # later terms compared against a literal boolean.
            fields__append((field_name, value))
        return reduce(or_, filters)

    def get_ordered_insertion_target(self, node, parent):
        """
        Attempts to retrieve a suitable right sibling for ``node``
        underneath ``parent`` (which may be ``None`` in the case of root
        nodes) so that ordering by the fields specified by the node's class'
        ``order_insertion_by`` option is maintained.

        Returns ``None`` if no suitable sibling can be found.
        """
        right_sibling = None
        # Optimisation - if the parent doesn't have descendants,
        # the node will always be its last child.
        if self.order_insertion_by and (
            parent is None or parent.get_descendant_count() > 0
        ):
            opts = node._mptt_meta
            # Copy: the list is extended below and may be the shared default.
            order_by = opts.order_insertion_by[:]
            filters = self.insertion_target_filters(node, order_by)
            if parent:
                filters = filters & Q(**{opts.parent_attr: parent})
                # Fall back on tree ordering if multiple child nodes have
                # the same values.
                order_by.append(opts.left_attr)
            else:
                filters = filters & Q(**{opts.parent_attr: None})
                # Fall back on tree id ordering if multiple root nodes have
                # the same values.
                order_by.append(opts.tree_id_attr)
            queryset = (
                node.__class__._tree_manager.db_manager(node._state.db)
                .filter(filters)
                .order_by(*order_by)
            )
            if node.pk:
                queryset = queryset.exclude(pk=node.pk)
            try:
                right_sibling = queryset[:1][0]
            except IndexError:
                # No suitable right sibling could be found
                pass
        return right_sibling
244
245
class MPTTModelBase(ModelBase):
    """
    Metaclass for MPTT models
    """

    def __new__(meta, class_name, bases, class_dict):
        """
        Create subclasses of MPTTModel. This:
         - adds the MPTT fields to the class
         - adds a TreeManager to the model
        """
        # An empty class named "NewBase" is built untouched.
        # NOTE(review): presumably the intermediate base produced by a
        # with_metaclass-style helper; confirm whether still needed.
        if class_name == "NewBase" and class_dict == {}:
            return super().__new__(meta, class_name, bases, class_dict)
        is_MPTTModel = False
        try:
            MPTTModel
        except NameError:
            # MPTTModel is not defined yet, so it is the class being created.
            is_MPTTModel = True

        MPTTMeta = class_dict.pop("MPTTMeta", None)
        if not MPTTMeta:

            class MPTTMeta:
                pass

        # Options declared directly on this class win over inherited ones.
        initial_options = frozenset(dir(MPTTMeta))

        # extend MPTTMeta from base classes
        for base in bases:
            if hasattr(base, "_mptt_meta"):
                for name, value in base._mptt_meta:
                    if name == "tree_manager_attr":
                        continue
                    if name not in initial_options:
                        setattr(MPTTMeta, name, value)

        class_dict["_mptt_meta"] = MPTTOptions(MPTTMeta)
        super_new = super().__new__
        cls = super_new(meta, class_name, bases, class_dict)
        cls = meta.register(cls)

        # see error cases in TreeManager.disable_mptt_updates for the reasoning here.
        # The tracking base is the first class (in cls's MRO, restricted to
        # MPTTModel subclasses) that is neither abstract nor a proxy and is
        # its own tree manager's tree_model. Only that class gets the
        # thread-local state.
        cls._mptt_tracking_base = None
        if is_MPTTModel:
            bases = [cls]
        else:
            bases = [base for base in cls.mro() if issubclass(base, MPTTModel)]
        for base in bases:
            if (
                not (base._meta.abstract or base._meta.proxy)
                and base._tree_manager.tree_model is base
            ):
                cls._mptt_tracking_base = base
                break
        if cls is cls._mptt_tracking_base:
            cls._threadlocal = threading.local()
            # set on first access (to make threading errors more obvious):
            #    cls._threadlocal.mptt_delayed_tree_changes = None

        return cls

    @classmethod
    def register(meta, cls, **kwargs):
        """
        For the weird cases when you need to add tree-ness to an *existing*
        class. For other cases you should subclass MPTTModel instead of calling this.

        ``kwargs`` build ``cls._mptt_meta`` when the class has none yet. The
        class is modified in place: MPTTModel is injected as a base if
        missing, the MPTT fields and the (tree_id, left) index_together are
        added to the tree model, and a ``_tree_manager`` is attached to
        concrete classes. Returns ``cls``.

        :raises ValueError: if ``cls`` is not a Django model class.
        """

        if not issubclass(cls, models.Model):
            raise ValueError(_("register() expects a Django model class argument"))

        if not hasattr(cls, "_mptt_meta"):
            cls._mptt_meta = MPTTOptions(**kwargs)

        abstract = getattr(cls._meta, "abstract", False)

        try:
            MPTTModel
        except NameError:
            # We're defining the base class right now, so don't do anything
            # We only want to add this stuff to the subclasses.
            # (Otherwise if field names are customized, we'll end up adding two
            # copies)
            pass
        else:
            if not issubclass(cls, MPTTModel):
                bases = list(cls.__bases__)

                # strip out bases that are strict superclasses of MPTTModel.
                # (i.e. Model, object)
                # this helps linearize the type hierarchy if possible
                for i in range(len(bases) - 1, -1, -1):
                    if issubclass(MPTTModel, bases[i]):
                        del bases[i]

                bases.insert(0, MPTTModel)
                cls.__bases__ = tuple(bases)

            is_cls_tree_model = _get_tree_model(cls) is cls

            if is_cls_tree_model:
                # HACK: _meta.get_field() doesn't work before AppCache.ready in Django>=1.8
                # ( see https://code.djangoproject.com/ticket/24231 )
                # So the only way to get existing fields is using local_fields on all superclasses.
                existing_field_names = set()
                for base in cls.mro():
                    if hasattr(base, "_meta"):
                        existing_field_names.update(
                            [f.name for f in base._meta.local_fields]
                        )

                mptt_meta = cls._mptt_meta
                indexed_attrs = (mptt_meta.tree_id_attr,)
                field_names = (
                    mptt_meta.left_attr,
                    mptt_meta.right_attr,
                    mptt_meta.tree_id_attr,
                    mptt_meta.level_attr,
                )

                for field_name in field_names:
                    if field_name not in existing_field_names:
                        field = models.PositiveIntegerField(
                            db_index=field_name in indexed_attrs, editable=False
                        )
                        field.contribute_to_class(cls, field_name)

                # Add an index_together on tree_id_attr and left_attr, as these are very
                # commonly queried (pretty much all reads). This runs for
                # abstract and concrete tree models alike, so no second
                # check is needed in the manager section below.
                index_together = (cls._mptt_meta.tree_id_attr, cls._mptt_meta.left_attr)
                if index_together not in cls._meta.index_together:
                    cls._meta.index_together += (index_together,)

            # Add a tree manager, if there isn't one already
            if not abstract:
                # make sure we have a tree manager somewhere
                tree_manager = None
                # Use the default manager defined on the class if any
                if cls._default_manager and isinstance(
                    cls._default_manager, TreeManager
                ):
                    tree_manager = cls._default_manager
                else:
                    for cls_manager in cls._meta.managers:
                        if isinstance(cls_manager, TreeManager):
                            # prefer any locally defined manager (i.e. keep going if not local)
                            if cls_manager.model is cls:
                                tree_manager = cls_manager
                                break

                if tree_manager and tree_manager.model is not cls:
                    tree_manager = tree_manager._copy_to_model(cls)
                elif tree_manager is None:
                    tree_manager = TreeManager()
                tree_manager.contribute_to_class(cls, "_tree_manager")

                # avoid using ManagerDescriptor, so instances can refer to self._tree_manager
                setattr(cls, "_tree_manager", tree_manager)
        return cls
414
415
def raise_if_unsaved(func):
    """
    Decorate a model method so it refuses to run on unsaved instances.

    The returned wrapper checks ``self._state.adding``. While it is truthy, a
    ``ValueError`` naming the method and the instance's class is raised.
    Otherwise the call is forwarded to ``func`` with all arguments.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self._state.adding:
            return func(self, *args, **kwargs)
        details = {"function": func.__name__, "class": self.__class__.__name__}
        raise ValueError(
            "Cannot call %(function)s on unsaved %(class)s instances" % details
        )

    return wrapper
427
428
429class MPTTModel(models.Model, metaclass=MPTTModelBase):
430    """
431    Base class for tree models.
432    """
433
434    class Meta:
435        abstract = True
436
437    objects = TreeManager()
438
439    def __init__(self, *args, **kwargs):
440        super().__init__(*args, **kwargs)
441        self._mptt_meta.update_mptt_cached_fields(self)
442
443    def _mpttfield(self, fieldname):
444        translated_fieldname = getattr(self._mptt_meta, fieldname + "_attr")
445        return getattr(self, translated_fieldname)
446
447    @_classproperty
448    def _mptt_updates_enabled(cls):
449        if not cls._mptt_tracking_base:
450            return True
451        return getattr(
452            cls._mptt_tracking_base._threadlocal, "mptt_updates_enabled", True
453        )
454
    # Ideally this would be a setter on the _mptt_updates_enabled classproperty,
    # but assigning to an attribute on the class itself rebinds it instead of
    # calling a descriptor setter, so a separate classmethod is used instead.
457    @classmethod
458    def _set_mptt_updates_enabled(cls, value):
459        assert (
460            cls is cls._mptt_tracking_base
461        ), "Can't enable or disable mptt updates on a non-tracking class."
462        cls._threadlocal.mptt_updates_enabled = value
463
464    @_classproperty
465    def _mptt_is_tracking(cls):
466        if not cls._mptt_tracking_base:
467            return False
468        if not hasattr(cls._threadlocal, "mptt_delayed_tree_changes"):
469            # happens the first time this is called from each thread
470            cls._threadlocal.mptt_delayed_tree_changes = None
471        return cls._threadlocal.mptt_delayed_tree_changes is not None
472
473    @classmethod
474    def _mptt_start_tracking(cls):
475        assert (
476            cls is cls._mptt_tracking_base
477        ), "Can't start or stop mptt tracking on a non-tracking class."
478        assert not cls._mptt_is_tracking, "mptt tracking is already started."
479        cls._threadlocal.mptt_delayed_tree_changes = set()
480
481    @classmethod
482    def _mptt_stop_tracking(cls):
483        assert (
484            cls is cls._mptt_tracking_base
485        ), "Can't start or stop mptt tracking on a non-tracking class."
486        assert cls._mptt_is_tracking, "mptt tracking isn't started."
487        results = cls._threadlocal.mptt_delayed_tree_changes
488        cls._threadlocal.mptt_delayed_tree_changes = None
489        return results
490
491    @classmethod
492    def _mptt_track_tree_modified(cls, tree_id):
493        if not cls._mptt_is_tracking:
494            return
495        cls._threadlocal.mptt_delayed_tree_changes.add(tree_id)
496
497    @classmethod
498    def _mptt_track_tree_insertions(cls, tree_id, num_inserted):
499        if not cls._mptt_is_tracking:
500            return
501        changes = cls._threadlocal.mptt_delayed_tree_changes
502        if not num_inserted or not changes:
503            return
504
505        if num_inserted < 0:
506            deleted = range(tree_id + num_inserted, -num_inserted)
507            changes.difference_update(deleted)
508        new_changes = {(t + num_inserted if t >= tree_id else t) for t in changes}
509        cls._threadlocal.mptt_delayed_tree_changes = new_changes
510
511    @raise_if_unsaved
512    def get_ancestors(self, ascending=False, include_self=False):
513        """
514        Creates a ``QuerySet`` containing the ancestors of this model
515        instance.
516
517        This defaults to being in descending order (root ancestor first,
518        immediate parent last); passing ``True`` for the ``ascending``
519        argument will reverse the ordering (immediate parent first, root
520        ancestor last).
521
522        If ``include_self`` is ``True``, the ``QuerySet`` will also
523        include this model instance.
524        """
525        opts = self._mptt_meta
526        if self.is_root_node():
527            if not include_self:
528                return self._tree_manager.none()
529            else:
530                # Filter on pk for efficiency.
531                qs = self._tree_manager.filter(pk=self.pk)
532        else:
533            order_by = opts.left_attr
534            if ascending:
535                order_by = "-" + order_by
536
537            left = getattr(self, opts.left_attr)
538            right = getattr(self, opts.right_attr)
539
540            if not include_self:
541                left -= 1
542                right += 1
543
544            qs = self._tree_manager._mptt_filter(
545                left__lte=left,
546                right__gte=right,
547                tree_id=self._mpttfield("tree_id"),
548            )
549
550            qs = qs.order_by(order_by)
551
552        if hasattr(self, "_mptt_use_cached_ancestors"):
553            # Called during or after a `recursetree` tag.
554            # There should be cached parents up to level 0.
555            # So we can use them to avoid doing a query at all.
556            ancestors = []
557            p = self
558            if not include_self:
559                p = getattr(p, opts.parent_attr)
560
561            while p is not None:
562                ancestors.append(p)
563                p = getattr(p, opts.parent_attr)
564
565            ancestors.reverse()
566            qs._result_cache = ancestors
567
568        return qs
569
570    @raise_if_unsaved
571    def get_family(self):
572        """
573        Returns a ``QuerySet`` containing the ancestors, the model itself
574        and the descendants, in tree order.
575        """
576        opts = self._mptt_meta
577
578        left = getattr(self, opts.left_attr)
579        right = getattr(self, opts.right_attr)
580
581        ancestors = Q(
582            **{
583                "%s__lte" % opts.left_attr: left,
584                "%s__gte" % opts.right_attr: right,
585                opts.tree_id_attr: self._mpttfield("tree_id"),
586            }
587        )
588
589        descendants = Q(
590            **{
591                "%s__gte" % opts.left_attr: left,
592                "%s__lte" % opts.left_attr: right,
593                opts.tree_id_attr: self._mpttfield("tree_id"),
594            }
595        )
596
597        return self._tree_manager.filter(ancestors | descendants)
598
599    @raise_if_unsaved
600    def get_children(self):
601        """
602        Returns a ``QuerySet`` containing the immediate children of this
603        model instance, in tree order.
604
605        The benefit of using this method over the reverse relation
606        provided by the ORM to the instance's children is that a
607        database query can be avoided in the case where the instance is
608        a leaf node (it has no children).
609
610        If called from a template where the tree has been walked by the
611        ``cache_tree_children`` filter, no database query is required.
612        """
613        if hasattr(self, "_cached_children"):
614            qs = self._tree_manager.filter(pk__in=[n.pk for n in self._cached_children])
615            qs._result_cache = self._cached_children
616            return qs
617        else:
618            if self.is_leaf_node():
619                return self._tree_manager.none()
620
621            return self._tree_manager._mptt_filter(parent=self)
622
623    @raise_if_unsaved
624    def get_descendants(self, include_self=False):
625        """
626        Creates a ``QuerySet`` containing descendants of this model
627        instance, in tree order.
628
629        If ``include_self`` is ``True``, the ``QuerySet`` will also
630        include this model instance.
631        """
632        if self.is_leaf_node():
633            if not include_self:
634                return self._tree_manager.none()
635            else:
636                return self._tree_manager.filter(pk=self.pk)
637
638        opts = self._mptt_meta
639        left = getattr(self, opts.left_attr)
640        right = getattr(self, opts.right_attr)
641
642        if not include_self:
643            left += 1
644            right -= 1
645
646        return self._tree_manager._mptt_filter(
647            tree_id=self._mpttfield("tree_id"), left__gte=left, left__lte=right
648        )
649
650    def get_descendant_count(self):
651        """
652        Returns the number of descendants this model instance has.
653        """
654        if self._mpttfield("right") is None:
655            # node not saved yet
656            return 0
657        else:
658            return (self._mpttfield("right") - self._mpttfield("left") - 1) // 2
659
660    @raise_if_unsaved
661    def get_leafnodes(self, include_self=False):
662        """
663        Creates a ``QuerySet`` containing leafnodes of this model
664        instance, in tree order.
665
666        If ``include_self`` is ``True``, the ``QuerySet`` will also
667        include this model instance (if it is a leaf node)
668        """
669        descendants = self.get_descendants(include_self=include_self)
670
671        return self._tree_manager._mptt_filter(
672            descendants, left=(models.F(self._mptt_meta.right_attr) - 1)
673        )
674
675    @raise_if_unsaved
676    def get_next_sibling(self, *filter_args, **filter_kwargs):
677        """
678        Returns this model instance's next sibling in the tree, or
679        ``None`` if it doesn't have a next sibling.
680        """
681        qs = self._tree_manager.filter(*filter_args, **filter_kwargs)
682        if self.is_root_node():
683            qs = self._tree_manager._mptt_filter(
684                qs,
685                parent=None,
686                tree_id__gt=self._mpttfield("tree_id"),
687            )
688        else:
689            qs = self._tree_manager._mptt_filter(
690                qs,
691                parent__pk=getattr(self, self._mptt_meta.parent_attr + "_id"),
692                left__gt=self._mpttfield("right"),
693            )
694
695        siblings = qs[:1]
696        return siblings and siblings[0] or None
697
698    @raise_if_unsaved
699    def get_previous_sibling(self, *filter_args, **filter_kwargs):
700        """
701        Returns this model instance's previous sibling in the tree, or
702        ``None`` if it doesn't have a previous sibling.
703        """
704        opts = self._mptt_meta
705        qs = self._tree_manager.filter(*filter_args, **filter_kwargs)
706        if self.is_root_node():
707            qs = self._tree_manager._mptt_filter(
708                qs,
709                parent=None,
710                tree_id__lt=self._mpttfield("tree_id"),
711            )
712            qs = qs.order_by("-" + opts.tree_id_attr)
713        else:
714            qs = self._tree_manager._mptt_filter(
715                qs,
716                parent__pk=getattr(self, opts.parent_attr + "_id"),
717                right__lt=self._mpttfield("left"),
718            )
719            qs = qs.order_by("-" + opts.right_attr)
720
721        siblings = qs[:1]
722        return siblings and siblings[0] or None
723
724    @raise_if_unsaved
725    def get_root(self):
726        """
727        Returns the root node of this model instance's tree.
728        """
729        if self.is_root_node() and type(self) == self._tree_manager.tree_model:
730            return self
731
732        return self._tree_manager._mptt_filter(
733            tree_id=self._mpttfield("tree_id"),
734            parent=None,
735        ).get()
736
737    @raise_if_unsaved
738    def get_siblings(self, include_self=False):
739        """
740        Creates a ``QuerySet`` containing siblings of this model
741        instance. Root nodes are considered to be siblings of other root
742        nodes.
743
744        If ``include_self`` is ``True``, the ``QuerySet`` will also
745        include this model instance.
746        """
747        if self.is_root_node():
748            queryset = self._tree_manager._mptt_filter(parent=None)
749        else:
750            parent_id = getattr(self, self._mptt_meta.parent_attr + "_id")
751            queryset = self._tree_manager._mptt_filter(parent__pk=parent_id)
752        if not include_self:
753            queryset = queryset.exclude(pk=self.pk)
754        return queryset
755
756    def get_level(self):
757        """
758        Returns the level of this node (distance from root)
759        """
760        return getattr(self, self._mptt_meta.level_attr)
761
762    def insert_at(
763        self,
764        target,
765        position="first-child",
766        save=False,
767        allow_existing_pk=False,
768        refresh_target=True,
769    ):
770        """
771        Convenience method for calling ``TreeManager.insert_node`` with this
772        model instance.
773        """
774        self._tree_manager.insert_node(
775            self,
776            target,
777            position,
778            save,
779            allow_existing_pk=allow_existing_pk,
780            refresh_target=refresh_target,
781        )
782
783    def is_child_node(self):
784        """
785        Returns ``True`` if this model instance is a child node, ``False``
786        otherwise.
787        """
788        return not self.is_root_node()
789
790    def is_leaf_node(self):
791        """
792        Returns ``True`` if this model instance is a leaf node (it has no
793        children), ``False`` otherwise.
794        """
795        return not self.get_descendant_count()
796
797    def is_root_node(self):
798        """
799        Returns ``True`` if this model instance is a root node,
800        ``False`` otherwise.
801        """
802        return getattr(self, self._mptt_meta.parent_attr + "_id") is None
803
804    @raise_if_unsaved
805    def is_descendant_of(self, other, include_self=False):
806        """
807        Returns ``True`` if this model is a descendant of the given node,
808        ``False`` otherwise.
809        If include_self is True, also returns True if the two nodes are the same node.
810        """
811        opts = self._mptt_meta
812
813        if include_self and other.pk == self.pk:
814            return True
815
816        if getattr(self, opts.tree_id_attr) != getattr(other, opts.tree_id_attr):
817            return False
818        else:
819            left = getattr(self, opts.left_attr)
820            right = getattr(self, opts.right_attr)
821
822            return left > getattr(other, opts.left_attr) and right < getattr(
823                other, opts.right_attr
824            )
825
826    @raise_if_unsaved
827    def is_ancestor_of(self, other, include_self=False):
828        """
829        Returns ``True`` if this model is an ancestor of the given node,
830        ``False`` otherwise.
831        If include_self is True, also returns True if the two nodes are the same node.
832        """
833        if include_self and other.pk == self.pk:
834            return True
835        return other.is_descendant_of(self)
836
837    def move_to(self, target, position="first-child"):
838        """
839        Convenience method for calling ``TreeManager.move_node`` with this
840        model instance.
841
842        NOTE: This is a low-level method; it does NOT respect ``MPTTMeta.order_insertion_by``.
843        In most cases you should just move the node yourself by setting node.parent.
844        """
845        self._tree_manager.move_node(self, target, position)
846
847    def _is_saved(self, using=None):
848        if self.pk is None or self._mpttfield("tree_id") is None:
849            return False
850        opts = self._meta
851        if opts.pk.remote_field is None:
852            return True
853        else:
854            if not hasattr(self, "_mptt_saved"):
855                manager = self.__class__._base_manager
856                manager = manager.using(using)
857                self._mptt_saved = manager.filter(pk=self.pk).exists()
858            return self._mptt_saved
859
860    def _get_user_field_names(self):
861        """Returns the list of user defined (i.e. non-mptt internal) field names."""
862        from django.db.models.fields import AutoField
863
864        field_names = []
865        internal_fields = (
866            self._mptt_meta.left_attr,
867            self._mptt_meta.right_attr,
868            self._mptt_meta.tree_id_attr,
869            self._mptt_meta.level_attr,
870        )
871        for field in self._meta.concrete_fields:
872            if (
873                (field.name not in internal_fields)
874                and (not isinstance(field, AutoField))
875                and (not field.primary_key)
876            ):  # noqa
877                field_names.append(field.name)
878        return field_names
879
    def save(self, *args, **kwargs):
        """
        If this is a new node, sets tree fields up before it is inserted
        into the database, making room in the tree structure as necessary,
        defaulting to making the new node the last child of its parent.

        If the node's left and right edge indicators have already been set,
        we take this as indication that the node has already been set up for
        insertion, so its tree fields are left untouched.

        If this is an existing node and its parent has been changed,
        performs reparenting in the tree structure, defaulting to making the
        node the last child of its new parent.

        In either case, if the node's class has its ``order_insertion_by``
        tree option set, the node will be inserted or moved to the
        appropriate position to maintain ordering by the specified field.

        When MPTT updates are disabled on the class and tree changes are not
        being tracked either, only placeholder tree values are filled in
        (for a node without a left value yet) and the row is saved as-is.

        All positional and keyword arguments are forwarded to Django's
        ``Model.save``. For an existing node that is not being moved, an
        empty ``update_fields`` is replaced with the user-defined (non-MPTT)
        field names, unless tree-change tracking is active.
        """
        # Class-level flags: whether tree updates run now, and whether
        # modified trees are being tracked instead (see the
        # disable/delay notes below).
        do_updates = self.__class__._mptt_updates_enabled
        track_updates = self.__class__._mptt_is_tracking

        opts = self._mptt_meta

        if not (do_updates or track_updates):
            # inside manager.disable_mptt_updates(), don't do any updates.
            # unless we're also inside TreeManager.delay_mptt_updates()
            if self._mpttfield("left") is None:
                # we need to set *some* values, though don't care too much what.
                parent = cached_field_value(self, opts.parent_attr)
                # if we have a cached parent, have a stab at getting
                # possibly-correct values.  otherwise, meh.
                if parent:
                    # Place the node as a leaf right after the parent's left
                    # edge and widen the cached parent's right edge by 2.
                    left = parent._mpttfield("left") + 1
                    setattr(self, opts.left_attr, left)
                    setattr(self, opts.right_attr, left + 1)
                    setattr(self, opts.level_attr, parent._mpttfield("level") + 1)
                    setattr(self, opts.tree_id_attr, parent._mpttfield("tree_id"))
                    self._tree_manager._post_insert_update_cached_parent_right(
                        parent, 2
                    )
                else:
                    setattr(self, opts.left_attr, 1)
                    setattr(self, opts.right_attr, 2)
                    setattr(self, opts.level_attr, 0)
                    setattr(self, opts.tree_id_attr, 0)
            return super().save(*args, **kwargs)

        # Raw FK value of the (possibly just reassigned) parent.
        parent_id = opts.get_raw_field_value(self, opts.parent_attr)

        # determine whether this instance is already in the db
        force_update = kwargs.get("force_update", False)
        force_insert = kwargs.get("force_insert", False)
        # tree_id of a former tree that must be collapsed after the save
        # (only set in tracking mode when a root node gains a parent).
        collapse_old_tree = None
        deferred_fields = self.get_deferred_fields()
        if force_update or (
            not force_insert and self._is_saved(using=kwargs.get("using"))
        ):
            # it already exists, so do a move
            # Parent id recorded by opts.update_mptt_cached_fields (called at
            # the end of this method) the last time values were snapshotted.
            old_parent_id = self._mptt_cached_fields[opts.parent_attr]
            if old_parent_id is DeferredAttribute:
                # No old value was recorded (presumably the field was
                # deferred); assume the node's position is unchanged.
                same_order = True
            else:
                same_order = old_parent_id == parent_id

            if same_order and len(self._mptt_cached_fields) > 1:
                # Other cached fields are also tracked (presumably the
                # order_insertion_by fields -- confirm in MPTTOptions); if any
                # of them changed, the node may need repositioning.
                for field_name, old_value in self._mptt_cached_fields.items():
                    if (
                        old_value is DeferredAttribute
                        and field_name not in deferred_fields
                    ):
                        same_order = False
                        break
                    if old_value != opts.get_raw_field_value(self, field_name):
                        same_order = False
                        break
                if not do_updates and not same_order:
                    # Tracking mode: don't reorder now, just mark the tree
                    # as modified so it can be fixed up later.
                    same_order = True
                    self.__class__._mptt_track_tree_modified(self._mpttfield("tree_id"))
            elif (not do_updates) and not same_order and old_parent_id is None:
                # the old tree no longer exists, so we need to collapse it.
                # Tracking mode, former root now has a parent: graft it in
                # as a leaf right after the new parent's left edge and mark
                # the new parent's tree as modified.
                collapse_old_tree = self._mpttfield("tree_id")
                parent = getattr(self, opts.parent_attr)
                tree_id = parent._mpttfield("tree_id")
                left = parent._mpttfield("left") + 1
                self.__class__._mptt_track_tree_modified(tree_id)
                setattr(self, opts.tree_id_attr, tree_id)
                setattr(self, opts.left_attr, left)
                setattr(self, opts.right_attr, left + 1)
                setattr(self, opts.level_attr, parent._mpttfield("level") + 1)
                same_order = True

            if not same_order:
                parent = getattr(self, opts.parent_attr)
                # Temporarily put the old parent id back while the move is
                # computed (presumably so the helpers see the node where it
                # currently is); the new id is restored in ``finally``.
                opts.set_raw_field_value(self, opts.parent_attr, old_parent_id)
                try:
                    right_sibling = opts.get_ordered_insertion_target(self, parent)

                    if parent_id is not None:
                        # If we aren't already a descendant of the new parent,
                        # we need to update the parent.rght so things like
                        # get_children and get_descendant_count work correctly.
                        #
                        # parent might be None if parent_id was assigned
                        # directly -- then we certainly do not have to update
                        # the cached parent.
                        update_cached_parent = parent and (
                            getattr(self, opts.tree_id_attr)
                            != getattr(parent, opts.tree_id_attr)
                            or getattr(self, opts.left_attr)  # noqa
                            < getattr(parent, opts.left_attr)
                            or getattr(self, opts.right_attr)
                            > getattr(parent, opts.right_attr)
                        )

                    if right_sibling:
                        # Ordered insertion: move directly left of the target.
                        self._tree_manager._move_node(
                            self,
                            right_sibling,
                            "left",
                            save=False,
                            refresh_target=False,
                        )
                    else:
                        # Default movement
                        if parent_id is None:
                            # Becoming a root: place this tree after the root
                            # with the highest tree_id other than itself; if
                            # there is no other root, leave it as it is.
                            root_nodes = self._tree_manager.root_nodes()
                            try:
                                rightmost_sibling = root_nodes.exclude(
                                    pk=self.pk
                                ).order_by("-" + opts.tree_id_attr)[0]
                                self._tree_manager._move_node(
                                    self,
                                    rightmost_sibling,
                                    "right",
                                    save=False,
                                    refresh_target=False,
                                )
                            except IndexError:
                                pass
                        else:
                            self._tree_manager._move_node(
                                self, parent, "last-child", save=False
                            )

                    # ``update_cached_parent`` is only bound when parent_id is
                    # not None; the short-circuit below relies on that.
                    if parent_id is not None and update_cached_parent:
                        # Update rght of cached parent
                        right_shift = 2 * (self.get_descendant_count() + 1)
                        self._tree_manager._post_insert_update_cached_parent_right(
                            parent, right_shift
                        )
                finally:
                    # Make sure the new parent is always
                    # restored on the way out in case of errors.
                    opts.set_raw_field_value(self, opts.parent_attr, parent_id)

                # If there were no exceptions raised then send a moved signal
                node_moved.send(
                    sender=self.__class__,
                    instance=self,
                    target=getattr(self, opts.parent_attr),
                )
            else:
                opts.set_raw_field_value(self, opts.parent_attr, parent_id)
                if not track_updates:
                    # When not using delayed/disabled updates,
                    # populate update_fields with user defined model fields.
                    # This helps preserve tree integrity when saving model on top
                    # of a modified tree.
                    # args[3] is the positional ``update_fields`` argument of
                    # Django's Model.save(force_insert, force_update, using,
                    # update_fields).
                    if len(args) > 3:
                        if not args[3]:
                            args = list(args)
                            args[3] = self._get_user_field_names()
                            args = tuple(args)
                    else:
                        if not kwargs.get("update_fields", None):
                            kwargs["update_fields"] = self._get_user_field_names()

        else:
            # new node, do an insert
            if getattr(self, opts.left_attr) and getattr(self, opts.right_attr):
                # This node has already been set up for insertion.
                pass
            else:
                parent = getattr(self, opts.parent_attr)

                right_sibling = None
                # if we're inside delay_mptt_updates, don't do queries to find
                # sibling position.  instead, do default insertion. correct
                # positions will be found during partial rebuild later.
                # *unless* this is a root node. (as update tracking doesn't
                # handle re-ordering of trees.)
                if do_updates or parent is None:
                    if opts.order_insertion_by:
                        right_sibling = opts.get_ordered_insertion_target(self, parent)

                if right_sibling:
                    self.insert_at(
                        right_sibling,
                        "left",
                        allow_existing_pk=True,
                        refresh_target=False,
                    )

                    if parent:
                        # since we didn't insert into parent, we have to update parent.rght
                        # here instead of in TreeManager.insert_node()
                        right_shift = 2 * (self.get_descendant_count() + 1)
                        self._tree_manager._post_insert_update_cached_parent_right(
                            parent, right_shift
                        )
                else:
                    # Default insertion
                    self.insert_at(
                        parent, position="last-child", allow_existing_pk=True
                    )
        try:
            super().save(*args, **kwargs)
        finally:
            # Close up the space of the abandoned tree even if the save
            # raised, so the tree-id sequence is not left with a hole.
            if collapse_old_tree is not None:
                self._tree_manager._create_tree_space(collapse_old_tree, -1)

        # Remember that the row exists (read by _is_saved) and snapshot the
        # tracked field values for change detection on the next save.
        self._mptt_saved = True
        opts.update_mptt_cached_fields(self)

    # Tells Django's template engine not to call this method from templates.
    save.alters_data = True
1105
1106    def delete(self, *args, **kwargs):
1107        """Calling ``delete`` on a node will delete it as well as its full
1108        subtree, as opposed to reattaching all the subnodes to its parent node.
1109
1110        There are no argument specific to a MPTT model, all the arguments will
1111        be passed directly to the django's ``Model.delete``.
1112
1113        ``delete`` will not return anything."""
1114        try:
1115            # We have to make sure we use database's mptt values, since they
1116            # could have changed between the moment the instance was retrieved and
1117            # the moment it is deleted.
1118            # This happens for example if you delete several nodes at once from a queryset.
1119            fields_to_refresh = [
1120                self._mptt_meta.right_attr,
1121                self._mptt_meta.left_attr,
1122                self._mptt_meta.tree_id_attr,
1123            ]
1124            self.refresh_from_db(fields=fields_to_refresh)
1125        except self.__class__.DoesNotExist:
1126            # In case the object was already deleted, we don't want to throw an exception
1127            pass
1128        tree_width = self._mpttfield("right") - self._mpttfield("left") + 1
1129        target_right = self._mpttfield("right")
1130        tree_id = self._mpttfield("tree_id")
1131        self._tree_manager._close_gap(tree_width, target_right, tree_id)
1132        parent = cached_field_value(self, self._mptt_meta.parent_attr)
1133        if parent:
1134            right_shift = -self.get_descendant_count() - 2
1135            self._tree_manager._post_insert_update_cached_parent_right(
1136                parent, right_shift
1137            )
1138
1139        return super().delete(*args, **kwargs)
1140
1141    delete.alters_data = True
1142
1143    def _mptt_refresh(self):
1144        if not self.pk:
1145            return
1146        manager = type(self)._tree_manager
1147        opts = self._mptt_meta
1148        values = (
1149            manager.using(self._state.db)
1150            .filter(pk=self.pk)
1151            .values(
1152                opts.left_attr,
1153                opts.right_attr,
1154                opts.level_attr,
1155                opts.tree_id_attr,
1156            )[0]
1157        )
1158        for k, v in values.items():
1159            setattr(self, k, v)
1160