import operator
import threading
from functools import reduce, wraps

from django.db import models
from django.db.models.base import ModelBase
from django.db.models.query import Q
from django.db.models.query_utils import DeferredAttribute
from django.utils.translation import gettext as _

from mptt.compat import cached_field_value
from mptt.fields import TreeForeignKey, TreeManyToManyField, TreeOneToOneField
from mptt.managers import TreeManager
from mptt.signals import node_moved
from mptt.utils import _get_tree_model


__all__ = (
    "TreeForeignKey",
    "TreeOneToOneField",
    "TreeManyToManyField",
    "TreeManager",
    "MPTTOptions",
    "MPTTModelBase",
    "MPTTModel",
)


class _classproperty:
    """
    Descriptor whose getter is called with the owner *class*, so the value is
    the same whether it is read from the class or from an instance.

    Assigning through an instance raises ``AttributeError`` unless a setter
    was supplied, in which case the setter is called with the instance's class.
    """

    def __init__(self, getter, setter=None):
        self.fget = getter
        self.fset = setter

    def __get__(self, instance, owner):
        return self.fget(owner)

    def __set__(self, instance, value):
        # Python invokes ``__set__(instance, value)`` for instance assignment;
        # the setter (like the getter) operates on the class.
        if not self.fset:
            raise AttributeError("This classproperty is read only")
        self.fset(type(instance), value)


class classpropertytype(property):
    """
    ``property`` subclass built from a class-statement-like signature
    ``(name, bases, members)``; the getter/setter/deleter/doc are taken from
    the ``__get__``/``__set__``/``__delete__``/``__doc__`` entries of
    ``members``.
    """

    def __init__(self, name, bases=(), members=None):
        if members is None:
            members = {}
        super().__init__(
            members.get("__get__"),
            members.get("__set__"),
            members.get("__delete__"),
            members.get("__doc__"),
        )


# NOTE(review): built with no members, so this is a property with no getter or
# setter; it appears to exist only so the name can be imported. Confirm
# before removing.
classproperty = classpropertytype("classproperty")


class MPTTOptions:
    """
    Options class for MPTT models. Use this as an inner class called ``MPTTMeta``::

        class MyModel(MPTTModel):
            class MPTTMeta:
                order_insertion_by = ['name']
                parent_attr = 'myparent'
    """

    order_insertion_by = []
    left_attr = "lft"
    right_attr = "rght"
    tree_id_attr = "tree_id"
    level_attr = "level"
    parent_attr = "parent"

    def __init__(self, opts=None, **kwargs):
        """
        Build options from the attributes of ``opts`` (typically an
        ``MPTTMeta`` class) plus ``kwargs``; later entries win.

        Raises ``ValueError`` if the removed ``tree_manager_attr`` option is
        given.
        """
        # Override defaults with options provided
        if opts:
            opts = list(opts.__dict__.items())
        else:
            opts = []
        opts.extend(list(kwargs.items()))

        if "tree_manager_attr" in [opt[0] for opt in opts]:
            raise ValueError(
                "`tree_manager_attr` has been removed; you should instantiate"
                " a TreeManager as a normal manager on your model instead."
            )

        for key, value in opts:
            # Skip dunder entries such as __module__ / __dict__ of the Meta class.
            if key[:2] == "__":
                continue
            setattr(self, key, value)

        # Normalize order_insertion_by to a list
        if isinstance(self.order_insertion_by, str):
            self.order_insertion_by = [self.order_insertion_by]
        elif isinstance(self.order_insertion_by, tuple):
            self.order_insertion_by = list(self.order_insertion_by)
        elif self.order_insertion_by is None:
            self.order_insertion_by = []

    def __iter__(self):
        """
        Yield ``(name, value)`` pairs for the public attributes stored on this
        instance. Class-level defaults that were never overridden are not
        included, because only ``self.__dict__`` is inspected.
        """
        return ((k, v) for k, v in self.__dict__.items() if k[0] != "_")

    # Helper methods for accessing tree attributes on models.
    def get_raw_field_value(self, instance, field_name):
        """
        Gets the value of the given fieldname for the instance.
        This is not the same as getattr().
        This function will return IDs for foreignkeys etc, rather than doing
        a database query.
        """
        field = instance._meta.get_field(field_name)
        return field.value_from_object(instance)

    def set_raw_field_value(self, instance, field_name, value):
        """
        Sets the value of the given fieldname for the instance.
        This is not the same as setattr().
        This function requires an ID for a foreignkey (etc) rather than an instance.
        """
        field = instance._meta.get_field(field_name)
        setattr(instance, field.attname, value)

    def update_mptt_cached_fields(self, instance):
        """
        Caches (in an instance._mptt_cached_fields dict) the original values of:
         - parent pk
         - fields specified in order_insertion_by

        These are used in save() to determine if the relevant fields have changed,
        so that the MPTT fields need to be updated.
        """
        instance._mptt_cached_fields = {}
        field_names = {self.parent_attr}
        if self.order_insertion_by:
            for f in self.order_insertion_by:
                # Strip the descending-order marker to get the field name.
                if f[0] == "-":
                    f = f[1:]
                field_names.add(f)
        deferred_fields = instance.get_deferred_fields()
        for field_name in field_names:
            if deferred_fields:
                field = instance._meta.get_field(field_name)
                if (
                    field.attname in deferred_fields
                    and field.attname not in instance.__dict__
                ):
                    # deferred attribute (i.e. via .only() or .defer())
                    # It'd be silly to cache this (that'd do a database query)
                    # Instead, we mark it as a deferred attribute here, then
                    # assume it hasn't changed during save(), unless it's no
                    # longer deferred.
                    instance._mptt_cached_fields[field_name] = DeferredAttribute
                    continue
            instance._mptt_cached_fields[field_name] = self.get_raw_field_value(
                instance, field_name
            )

    def insertion_target_filters(self, instance, order_insertion_by):
        """
        Creates a filter which matches suitable right siblings for ``node``,
        where insertion should maintain ordering according to the list of
        fields in ``order_insertion_by``.

        For example, given an ``order_insertion_by`` of
        ``['field1', 'field2', 'field3']``, the resulting filter should
        correspond to the following SQL::

           field1 > %s
           OR (field1 = %s AND field2 > %s)
           OR (field1 = %s AND field2 = %s AND field3 > %s)

        A ``None`` value is matched with ``__isnull`` lookups, both in the
        "greater/less than" term and in the equality terms of later tiers.
        ``order_insertion_by`` must be non-empty.
        """
        # Q objects pinning each already-processed field to the node's value.
        equal_filters = []
        filters = []
        for field_name in order_insertion_by:
            if field_name[0] == "-":
                field_name = field_name[1:]
                filter_suffix = "__lt"
            else:
                filter_suffix = "__gt"
            value = getattr(instance, field_name)
            if value is None:
                # node isn't saved yet. get the insertion value from pre_save.
                field = instance._meta.get_field(field_name)
                value = field.pre_save(instance, True)

            if value is None:
                # We have to use __isnull instead of __lt or __gt because
                # comparing against NULL is invalid. Depending on the order,
                # we need to find the first node where the field is null or
                # not null.
                q = Q(**{field_name + "__isnull": filter_suffix == "__lt"})
                # Equality with NULL must also be expressed via __isnull.
                equal_q = Q(**{field_name + "__isnull": True})
            else:
                q = Q(**{field_name + filter_suffix: value})
                equal_q = Q(**{field_name: value})

            filters.append(reduce(operator.and_, equal_filters + [q]))
            equal_filters.append(equal_q)
        return reduce(operator.or_, filters)

    def get_ordered_insertion_target(self, node, parent):
        """
        Attempts to retrieve a suitable right sibling for ``node``
        underneath ``parent`` (which may be ``None`` in the case of root
        nodes) so that ordering by the fields specified by the node's class'
        ``order_insertion_by`` option is maintained.

        Returns ``None`` if no suitable sibling can be found.
        """
        right_sibling = None
        # Optimisation - if the parent doesn't have descendants,
        # the node will always be its last child.
        if self.order_insertion_by and (
            parent is None or parent.get_descendant_count() > 0
        ):
            opts = node._mptt_meta
            order_by = opts.order_insertion_by[:]
            filters = self.insertion_target_filters(node, order_by)
            if parent:
                filters = filters & Q(**{opts.parent_attr: parent})
                # Fall back on tree ordering if multiple child nodes have
                # the same values.
                order_by.append(opts.left_attr)
            else:
                filters = filters & Q(**{opts.parent_attr: None})
                # Fall back on tree id ordering if multiple root nodes have
                # the same values.
                order_by.append(opts.tree_id_attr)
            queryset = (
                node.__class__._tree_manager.db_manager(node._state.db)
                .filter(filters)
                .order_by(*order_by)
            )
            if node.pk:
                queryset = queryset.exclude(pk=node.pk)
            try:
                right_sibling = queryset[:1][0]
            except IndexError:
                # No suitable right sibling could be found
                pass
        return right_sibling


class MPTTModelBase(ModelBase):
    """
    Metaclass for MPTT models
    """

    def __new__(meta, class_name, bases, class_dict):
        """
        Create subclasses of MPTTModel. This:
         - adds the MPTT fields to the class
         - adds a TreeManager to the model
        """
        if class_name == "NewBase" and class_dict == {}:
            return super().__new__(meta, class_name, bases, class_dict)
        # True only while MPTTModel itself is being defined (the name does not
        # exist yet at that point).
        is_MPTTModel = False
        try:
            MPTTModel
        except NameError:
            is_MPTTModel = True

        MPTTMeta = class_dict.pop("MPTTMeta", None)
        if not MPTTMeta:

            class MPTTMeta:
                pass

        initial_options = frozenset(dir(MPTTMeta))

        # extend MPTTMeta from base classes
        for base in bases:
            if hasattr(base, "_mptt_meta"):
                for name, value in base._mptt_meta:
                    if name == "tree_manager_attr":
                        continue
                    # Options declared locally take precedence over inherited ones.
                    if name not in initial_options:
                        setattr(MPTTMeta, name, value)

        class_dict["_mptt_meta"] = MPTTOptions(MPTTMeta)
        super_new = super().__new__
        cls = super_new(meta, class_name, bases, class_dict)
        cls = meta.register(cls)

        # see error cases in TreeManager.disable_mptt_updates for the reasoning here.
        cls._mptt_tracking_base = None
        if is_MPTTModel:
            bases = [cls]
        else:
            bases = [base for base in cls.mro() if issubclass(base, MPTTModel)]
        for base in bases:
            if (
                not (base._meta.abstract or base._meta.proxy)
                and base._tree_manager.tree_model is base
            ):
                cls._mptt_tracking_base = base
                break
        if cls is cls._mptt_tracking_base:
            cls._threadlocal = threading.local()
            # set on first access (to make threading errors more obvious):
            #  cls._threadlocal.mptt_delayed_tree_changes = None

        return cls

    @classmethod
    def register(meta, cls, **kwargs):
        """
        For the weird cases when you need to add tree-ness to an *existing*
        class. For other cases you should subclass MPTTModel instead of calling this.
        """

        if not issubclass(cls, models.Model):
            raise ValueError(_("register() expects a Django model class argument"))

        if not hasattr(cls, "_mptt_meta"):
            cls._mptt_meta = MPTTOptions(**kwargs)

        abstract = getattr(cls._meta, "abstract", False)

        try:
            MPTTModel
        except NameError:
            # We're defining the base class right now, so don't do anything
            # We only want to add this stuff to the subclasses.
            # (Otherwise if field names are customized, we'll end up adding two
            # copies)
            pass
        else:
            if not issubclass(cls, MPTTModel):
                bases = list(cls.__bases__)

                # strip out bases that are strict superclasses of MPTTModel.
                # (i.e. Model, object)
                # this helps linearize the type hierarchy if possible
                for i in range(len(bases) - 1, -1, -1):
                    if issubclass(MPTTModel, bases[i]):
                        del bases[i]

                bases.insert(0, MPTTModel)
                cls.__bases__ = tuple(bases)

            is_cls_tree_model = _get_tree_model(cls) is cls

            if is_cls_tree_model:
                # HACK: _meta.get_field() doesn't work before AppCache.ready in Django>=1.8
                # ( see https://code.djangoproject.com/ticket/24231 )
                # So the only way to get existing fields is using local_fields on all superclasses.
                existing_field_names = set()
                for base in cls.mro():
                    if hasattr(base, "_meta"):
                        existing_field_names.update(
                            [f.name for f in base._meta.local_fields]
                        )

                mptt_meta = cls._mptt_meta
                indexed_attrs = (mptt_meta.tree_id_attr,)
                field_names = (
                    mptt_meta.left_attr,
                    mptt_meta.right_attr,
                    mptt_meta.tree_id_attr,
                    mptt_meta.level_attr,
                )

                for field_name in field_names:
                    if field_name not in existing_field_names:
                        field = models.PositiveIntegerField(
                            db_index=field_name in indexed_attrs, editable=False
                        )
                        field.contribute_to_class(cls, field_name)

                # Add an index_together on tree_id_attr and left_attr, as these are very
                # commonly queried (pretty much all reads).
                index_together = (cls._mptt_meta.tree_id_attr, cls._mptt_meta.left_attr)
                if index_together not in cls._meta.index_together:
                    cls._meta.index_together += (index_together,)

            # Add a tree manager, if there isn't one already
            if not abstract:
                # make sure we have a tree manager somewhere
                tree_manager = None
                # Use the default manager defined on the class if any
                if cls._default_manager and isinstance(
                    cls._default_manager, TreeManager
                ):
                    tree_manager = cls._default_manager
                else:
                    for cls_manager in cls._meta.managers:
                        if isinstance(cls_manager, TreeManager):
                            # prefer any locally defined manager (i.e. keep going if not local)
                            if cls_manager.model is cls:
                                tree_manager = cls_manager
                                break

                # (The (tree_id, lft) index_together entry was already added
                # above under the same ``is_cls_tree_model`` condition.)

                if tree_manager and tree_manager.model is not cls:
                    tree_manager = tree_manager._copy_to_model(cls)
                elif tree_manager is None:
                    tree_manager = TreeManager()
                tree_manager.contribute_to_class(cls, "_tree_manager")

                # avoid using ManagerDescriptor, so instances can refer to self._tree_manager
                setattr(cls, "_tree_manager", tree_manager)
        return cls


def raise_if_unsaved(func):
    """
    Decorator for node methods: raise ``ValueError`` when called on an
    instance whose ``_state.adding`` flag is set (i.e. not yet saved).
    """

    @wraps(func)
    def _fn(self, *args, **kwargs):
        if self._state.adding:
            raise ValueError(
                "Cannot call %(function)s on unsaved %(class)s instances"
                % {"function": func.__name__, "class": self.__class__.__name__}
            )
        return func(self, *args, **kwargs)

    return _fn


class MPTTModel(models.Model, metaclass=MPTTModelBase):
    """
    Base class for tree models.
    """

    class Meta:
        abstract = True

    objects = TreeManager()

    def __init__(self, *args, **kwargs):
        """Initialise the model, then snapshot the parent/ordering fields."""
        super().__init__(*args, **kwargs)
        self._mptt_meta.update_mptt_cached_fields(self)

    def _mpttfield(self, fieldname):
        """
        Return the value of the MPTT field whose option name is
        ``fieldname + "_attr"`` (e.g. ``"left"`` -> ``left_attr``).
        """
        translated_fieldname = getattr(self._mptt_meta, fieldname + "_attr")
        return getattr(self, translated_fieldname)

    @_classproperty
    def _mptt_updates_enabled(cls):
        """
        Whether tree updates are enabled for the current thread. Always True
        for classes without a tracking base; otherwise read from the tracking
        base's thread-local state (default True).
        """
        if not cls._mptt_tracking_base:
            return True
        return getattr(
            cls._mptt_tracking_base._threadlocal, "mptt_updates_enabled", True
        )

    # ideally this'd be part of the _mptt_updates_enabled classproperty, but it seems
    # that settable classproperties are very, very hard to do! suggestions please :)
    @classmethod
    def _set_mptt_updates_enabled(cls, value):
        """Set the per-thread updates-enabled flag; only on the tracking class."""
        assert (
            cls is cls._mptt_tracking_base
        ), "Can't enable or disable mptt updates on a non-tracking class."
        cls._threadlocal.mptt_updates_enabled = value

    @_classproperty
    def _mptt_is_tracking(cls):
        """
        Whether delayed tree-change tracking is active in the current thread
        (i.e. a set of modified tree ids is being collected).
        """
        if not cls._mptt_tracking_base:
            return False
        if not hasattr(cls._threadlocal, "mptt_delayed_tree_changes"):
            # happens the first time this is called from each thread
            cls._threadlocal.mptt_delayed_tree_changes = None
        return cls._threadlocal.mptt_delayed_tree_changes is not None

    @classmethod
    def _mptt_start_tracking(cls):
        """Begin collecting modified tree ids for the current thread."""
        assert (
            cls is cls._mptt_tracking_base
        ), "Can't start or stop mptt tracking on a non-tracking class."
        assert not cls._mptt_is_tracking, "mptt tracking is already started."
        cls._threadlocal.mptt_delayed_tree_changes = set()

    @classmethod
    def _mptt_stop_tracking(cls):
        """Stop tracking and return the set of tree ids collected so far."""
        assert (
            cls is cls._mptt_tracking_base
        ), "Can't start or stop mptt tracking on a non-tracking class."
        assert cls._mptt_is_tracking, "mptt tracking isn't started."
        results = cls._threadlocal.mptt_delayed_tree_changes
        cls._threadlocal.mptt_delayed_tree_changes = None
        return results

    @classmethod
    def _mptt_track_tree_modified(cls, tree_id):
        """Record ``tree_id`` as modified, if tracking is active."""
        if not cls._mptt_is_tracking:
            return
        cls._threadlocal.mptt_delayed_tree_changes.add(tree_id)

    @classmethod
    def _mptt_track_tree_insertions(cls, tree_id, num_inserted):
        """
        Keep the recorded tree ids valid after every tree id ``>= tree_id``
        has been shifted by ``num_inserted``.

        A negative ``num_inserted`` means the ``-num_inserted`` ids directly
        below ``tree_id`` (``tree_id + num_inserted`` .. ``tree_id - 1``) no
        longer exist; they are dropped from the recorded set before shifting.
        """
        if not cls._mptt_is_tracking:
            return
        changes = cls._threadlocal.mptt_delayed_tree_changes
        if not num_inserted or not changes:
            return

        if num_inserted < 0:
            deleted = range(tree_id + num_inserted, tree_id)
            changes.difference_update(deleted)
        new_changes = {(t + num_inserted if t >= tree_id else t) for t in changes}
        cls._threadlocal.mptt_delayed_tree_changes = new_changes

    @raise_if_unsaved
    def get_ancestors(self, ascending=False, include_self=False):
        """
        Creates a ``QuerySet`` containing the ancestors of this model
        instance.

        This defaults to being in descending order (root ancestor first,
        immediate parent last); passing ``True`` for the ``ascending``
        argument will reverse the ordering (immediate parent first, root
        ancestor last).

        If ``include_self`` is ``True``, the ``QuerySet`` will also
        include this model instance.

        When ``_mptt_use_cached_ancestors`` is set on the instance, the result
        cache is pre-filled from the cached parent chain (same ordering rules).
        """
        opts = self._mptt_meta
        if self.is_root_node():
            if not include_self:
                return self._tree_manager.none()
            else:
                # Filter on pk for efficiency.
                qs = self._tree_manager.filter(pk=self.pk)
        else:
            order_by = opts.left_attr
            if ascending:
                order_by = "-" + order_by

            left = getattr(self, opts.left_attr)
            right = getattr(self, opts.right_attr)

            # Widen the bounds by one so the inclusive lookups below become
            # strict (excluding this node).
            if not include_self:
                left -= 1
                right += 1

            qs = self._tree_manager._mptt_filter(
                left__lte=left,
                right__gte=right,
                tree_id=self._mpttfield("tree_id"),
            )

            qs = qs.order_by(order_by)

        if hasattr(self, "_mptt_use_cached_ancestors"):
            # Called during or after a `recursetree` tag.
            # There should be cached parents up to level 0.
            # So we can use them to avoid doing a query at all.
            ancestors = []
            p = self
            if not include_self:
                p = getattr(p, opts.parent_attr)

            while p is not None:
                ancestors.append(p)
                p = getattr(p, opts.parent_attr)

            # ``ancestors`` is collected nearest-first (ascending order);
            # reverse it only for the default root-first ordering.
            if not ascending:
                ancestors.reverse()
            qs._result_cache = ancestors

        return qs

    @raise_if_unsaved
    def get_family(self):
        """
        Returns a ``QuerySet`` containing the ancestors, the model itself
        and the descendants, in tree order.
        """
        opts = self._mptt_meta

        left = getattr(self, opts.left_attr)
        right = getattr(self, opts.right_attr)

        ancestors = Q(
            **{
                "%s__lte" % opts.left_attr: left,
                "%s__gte" % opts.right_attr: right,
                opts.tree_id_attr: self._mpttfield("tree_id"),
            }
        )

        descendants = Q(
            **{
                "%s__gte" % opts.left_attr: left,
                "%s__lte" % opts.left_attr: right,
                opts.tree_id_attr: self._mpttfield("tree_id"),
            }
        )

        return self._tree_manager.filter(ancestors | descendants)

    @raise_if_unsaved
    def get_children(self):
        """
        Returns a ``QuerySet`` containing the immediate children of this
        model instance, in tree order.

        The benefit of using this method over the reverse relation
        provided by the ORM to the instance's children is that a
        database query can be avoided in the case where the instance is
        a leaf node (it has no children).

        If called from a template where the tree has been walked by the
        ``cache_tree_children`` filter, no database query is required.
        """
        if hasattr(self, "_cached_children"):
            qs = self._tree_manager.filter(pk__in=[n.pk for n in self._cached_children])
            qs._result_cache = self._cached_children
            return qs
        else:
            if self.is_leaf_node():
                return self._tree_manager.none()

            return self._tree_manager._mptt_filter(parent=self)

    @raise_if_unsaved
    def get_descendants(self, include_self=False):
        """
        Creates a ``QuerySet`` containing descendants of this model
        instance, in tree order.

        If ``include_self`` is ``True``, the ``QuerySet`` will also
        include this model instance.
        """
        if self.is_leaf_node():
            if not include_self:
                return self._tree_manager.none()
            else:
                return self._tree_manager.filter(pk=self.pk)

        opts = self._mptt_meta
        left = getattr(self, opts.left_attr)
        right = getattr(self, opts.right_attr)

        # Narrow the bounds by one so this node itself is excluded.
        if not include_self:
            left += 1
            right -= 1

        return self._tree_manager._mptt_filter(
            tree_id=self._mpttfield("tree_id"), left__gte=left, left__lte=right
        )

    def get_descendant_count(self):
        """
        Returns the number of descendants this model instance has.
        """
        if self._mpttfield("right") is None:
            # node not saved yet
            return 0
        else:
            return (self._mpttfield("right") - self._mpttfield("left") - 1) // 2

    @raise_if_unsaved
    def get_leafnodes(self, include_self=False):
        """
        Creates a ``QuerySet`` containing leafnodes of this model
        instance, in tree order.

        If ``include_self`` is ``True``, the ``QuerySet`` will also
        include this model instance (if it is a leaf node)
        """
        descendants = self.get_descendants(include_self=include_self)

        # A leaf node has lft == rght - 1.
        return self._tree_manager._mptt_filter(
            descendants, left=(models.F(self._mptt_meta.right_attr) - 1)
        )

    @raise_if_unsaved
    def get_next_sibling(self, *filter_args, **filter_kwargs):
        """
        Returns this model instance's next sibling in the tree, or
        ``None`` if it doesn't have a next sibling.
        """
        qs = self._tree_manager.filter(*filter_args, **filter_kwargs)
        if self.is_root_node():
            qs = self._tree_manager._mptt_filter(
                qs,
                parent=None,
                tree_id__gt=self._mpttfield("tree_id"),
            )
        else:
            qs = self._tree_manager._mptt_filter(
                qs,
                parent__pk=getattr(self, self._mptt_meta.parent_attr + "_id"),
                left__gt=self._mpttfield("right"),
            )

        siblings = qs[:1]
        return siblings and siblings[0] or None

    @raise_if_unsaved
    def get_previous_sibling(self, *filter_args, **filter_kwargs):
        """
        Returns this model instance's previous sibling in the tree, or
        ``None`` if it doesn't have a previous sibling.
        """
        opts = self._mptt_meta
        qs = self._tree_manager.filter(*filter_args, **filter_kwargs)
        if self.is_root_node():
            qs = self._tree_manager._mptt_filter(
                qs,
                parent=None,
                tree_id__lt=self._mpttfield("tree_id"),
            )
            qs = qs.order_by("-" + opts.tree_id_attr)
        else:
            qs = self._tree_manager._mptt_filter(
                qs,
                parent__pk=getattr(self, opts.parent_attr + "_id"),
                right__lt=self._mpttfield("left"),
            )
            qs = qs.order_by("-" + opts.right_attr)

        siblings = qs[:1]
        return siblings and siblings[0] or None

    @raise_if_unsaved
    def get_root(self):
        """
        Returns the root node of this model instance's tree.
        """
        if self.is_root_node() and type(self) is self._tree_manager.tree_model:
            return self

        return self._tree_manager._mptt_filter(
            tree_id=self._mpttfield("tree_id"),
            parent=None,
        ).get()

    @raise_if_unsaved
    def get_siblings(self, include_self=False):
        """
        Creates a ``QuerySet`` containing siblings of this model
        instance. Root nodes are considered to be siblings of other root
        nodes.

        If ``include_self`` is ``True``, the ``QuerySet`` will also
        include this model instance.
        """
        if self.is_root_node():
            queryset = self._tree_manager._mptt_filter(parent=None)
        else:
            parent_id = getattr(self, self._mptt_meta.parent_attr + "_id")
            queryset = self._tree_manager._mptt_filter(parent__pk=parent_id)
        if not include_self:
            queryset = queryset.exclude(pk=self.pk)
        return queryset

    def get_level(self):
        """
        Returns the level of this node (distance from root)
        """
        return getattr(self, self._mptt_meta.level_attr)

    def insert_at(
        self,
        target,
        position="first-child",
        save=False,
        allow_existing_pk=False,
        refresh_target=True,
    ):
        """
        Convenience method for calling ``TreeManager.insert_node`` with this
        model instance.
        """
        self._tree_manager.insert_node(
            self,
            target,
            position,
            save,
            allow_existing_pk=allow_existing_pk,
            refresh_target=refresh_target,
        )

    def is_child_node(self):
        """
        Returns ``True`` if this model instance is a child node, ``False``
        otherwise.
        """
        return not self.is_root_node()

    def is_leaf_node(self):
        """
        Returns ``True`` if this model instance is a leaf node (it has no
        children), ``False`` otherwise.
        """
        return not self.get_descendant_count()

    def is_root_node(self):
        """
        Returns ``True`` if this model instance is a root node,
        ``False`` otherwise.
        """
        return getattr(self, self._mptt_meta.parent_attr + "_id") is None

    @raise_if_unsaved
    def is_descendant_of(self, other, include_self=False):
        """
        Returns ``True`` if this model is a descendant of the given node,
        ``False`` otherwise.
        If include_self is True, also returns True if the two nodes are the same node.
        """
        opts = self._mptt_meta

        if include_self and other.pk == self.pk:
            return True

        if getattr(self, opts.tree_id_attr) != getattr(other, opts.tree_id_attr):
            return False
        else:
            left = getattr(self, opts.left_attr)
            right = getattr(self, opts.right_attr)

            return left > getattr(other, opts.left_attr) and right < getattr(
                other, opts.right_attr
            )

    @raise_if_unsaved
    def is_ancestor_of(self, other, include_self=False):
        """
        Returns ``True`` if this model is an ancestor of the given node,
        ``False`` otherwise.
        If include_self is True, also returns True if the two nodes are the same node.
        """
        if include_self and other.pk == self.pk:
            return True
        return other.is_descendant_of(self)

    def move_to(self, target, position="first-child"):
        """
        Convenience method for calling ``TreeManager.move_node`` with this
        model instance.

        NOTE: This is a low-level method; it does NOT respect ``MPTTMeta.order_insertion_by``.
        In most cases you should just move the node yourself by setting node.parent.
        """
        self._tree_manager.move_node(self, target, position)

    def _is_saved(self, using=None):
        """
        Return whether this instance appears to exist in the database.

        False when pk or tree_id is unset. When the pk is a plain field this
        returns True without a query; when the pk is a relation (e.g. a
        one-to-one parent link) an existence query is run once and its result
        is cached in ``self._mptt_saved``.
        """
        if self.pk is None or self._mpttfield("tree_id") is None:
            return False
        opts = self._meta
        if opts.pk.remote_field is None:
            return True
        else:
            if not hasattr(self, "_mptt_saved"):
                manager = self.__class__._base_manager
                manager = manager.using(using)
                self._mptt_saved = manager.filter(pk=self.pk).exists()
            return self._mptt_saved

    def _get_user_field_names(self):
        """Returns the list of user defined (i.e. non-mptt internal) field names."""
        from django.db.models.fields import AutoField

        field_names = []
        internal_fields = (
            self._mptt_meta.left_attr,
            self._mptt_meta.right_attr,
            self._mptt_meta.tree_id_attr,
            self._mptt_meta.level_attr,
        )
        for field in self._meta.concrete_fields:
            if (
                (field.name not in internal_fields)
                and (not isinstance(field, AutoField))
                and (not field.primary_key)
            ):  # noqa
                field_names.append(field.name)
        return field_names

    def save(self, *args, **kwargs):
        """
        If this is a new node, sets tree fields up before it is inserted
        into the database, making room in the tree structure as necessary,
        defaulting to making the new node the last child of its parent.

        If the node's left and right edge indicators have already been set,
        we take this as indication that the node has already been set up for
        insertion, so its tree fields are left untouched.

        If this is an existing node and its parent has been changed,
        performs reparenting in the tree structure, defaulting to making the
        node the last child of its new parent.

        In either case, if the node's class has its ``order_insertion_by``
        tree option set, the node will be inserted or moved to the
        appropriate position to maintain ordering by the specified field.
        """
        do_updates = self.__class__._mptt_updates_enabled
        track_updates = self.__class__._mptt_is_tracking

        opts = self._mptt_meta

        if not (do_updates or track_updates):
            # inside manager.disable_mptt_updates(), don't do any updates.
            # unless we're also inside TreeManager.delay_mptt_updates()
            if self._mpttfield("left") is None:
                # we need to set *some* values, though don't care too much what.
                parent = cached_field_value(self, opts.parent_attr)
                # if we have a cached parent, have a stab at getting
                # possibly-correct values.  otherwise, meh.
                if parent:
                    left = parent._mpttfield("left") + 1
                    setattr(self, opts.left_attr, left)
                    setattr(self, opts.right_attr, left + 1)
                    setattr(self, opts.level_attr, parent._mpttfield("level") + 1)
                    setattr(self, opts.tree_id_attr, parent._mpttfield("tree_id"))
                    self._tree_manager._post_insert_update_cached_parent_right(
                        parent, 2
                    )
                else:
                    setattr(self, opts.left_attr, 1)
                    setattr(self, opts.right_attr, 2)
                    setattr(self, opts.level_attr, 0)
                    setattr(self, opts.tree_id_attr, 0)
            return super().save(*args, **kwargs)

        parent_id = opts.get_raw_field_value(self, opts.parent_attr)

        # determine whether this instance is already in the db
        force_update = kwargs.get("force_update", False)
        force_insert = kwargs.get("force_insert", False)
        collapse_old_tree = None
        deferred_fields = self.get_deferred_fields()
        if force_update or (
            not force_insert and self._is_saved(using=kwargs.get("using"))
        ):
            # it already exists, so do a move
            old_parent_id = self._mptt_cached_fields[opts.parent_attr]
            if old_parent_id is DeferredAttribute:
                same_order = True
            else:
                same_order = old_parent_id == parent_id

            if same_order and len(self._mptt_cached_fields) > 1:
                for field_name, old_value in self._mptt_cached_fields.items():
                    if old_value is DeferredAttribute:
                        # get_deferred_fields() reports attnames (e.g. "x_id"
                        # for foreign keys), so compare against the attname.
                        attname = self._meta.get_field(field_name).attname
                        if attname in deferred_fields:
                            # Still deferred: assume it hasn't changed (see
                            # update_mptt_cached_fields). Don't load it just to
                            # compare against the DeferredAttribute sentinel.
                            continue
                        same_order = False
                        break
                    if old_value != opts.get_raw_field_value(self, field_name):
                        same_order = False
                        break
            if not do_updates and not same_order:
                same_order = True
                self.__class__._mptt_track_tree_modified(self._mpttfield("tree_id"))
            elif (not do_updates) and not same_order and old_parent_id is None:
                # NOTE(review): this branch appears unreachable; the preceding
                # ``if`` already handles every ``not do_updates and not
                # same_order`` case. Left unchanged pending confirmation of
                # the intended condition.
                # the old tree no longer exists, so we need to collapse it.
                collapse_old_tree = self._mpttfield("tree_id")
                parent = getattr(self, opts.parent_attr)
                tree_id = parent._mpttfield("tree_id")
                left = parent._mpttfield("left") + 1
                self.__class__._mptt_track_tree_modified(tree_id)
                setattr(self, opts.tree_id_attr, tree_id)
                setattr(self, opts.left_attr, left)
                setattr(self, opts.right_attr, left + 1)
                setattr(self, opts.level_attr, parent._mpttfield("level") + 1)
                same_order = True

            if not same_order:
                parent = getattr(self, opts.parent_attr)
                # Temporarily restore the old parent id so the move below
                # starts from the node's current position in the tree.
                opts.set_raw_field_value(self, opts.parent_attr, old_parent_id)
                try:
                    right_sibling = opts.get_ordered_insertion_target(self, parent)

                    if parent_id is not None:
                        # If we aren't already a descendant of the new parent,
                        # we need to update the parent.rght so things like
                        # get_children and get_descendant_count work correctly.
                        #
                        # parent might be None if parent_id was assigned
                        # directly -- then we certainly do not have to update
                        # the cached parent.
                        update_cached_parent = parent and (
                            getattr(self, opts.tree_id_attr)
                            != getattr(parent, opts.tree_id_attr)
                            or getattr(self, opts.left_attr)  # noqa
                            < getattr(parent, opts.left_attr)
                            or getattr(self, opts.right_attr)
                            > getattr(parent, opts.right_attr)
                        )

                    if right_sibling:
                        self._tree_manager._move_node(
                            self,
                            right_sibling,
                            "left",
                            save=False,
                            refresh_target=False,
                        )
                    else:
                        # Default movement
                        if parent_id is None:
                            root_nodes = self._tree_manager.root_nodes()
                            try:
                                rightmost_sibling = root_nodes.exclude(
                                    pk=self.pk
                                ).order_by("-" + opts.tree_id_attr)[0]
                                self._tree_manager._move_node(
                                    self,
                                    rightmost_sibling,
                                    "right",
                                    save=False,
                                    refresh_target=False,
                                )
                            except IndexError:
                                pass
                        else:
                            self._tree_manager._move_node(
                                self, parent, "last-child", save=False
                            )

                    if parent_id is not None and update_cached_parent:
                        # Update rght of cached parent
                        right_shift = 2 * (self.get_descendant_count() + 1)
                        self._tree_manager._post_insert_update_cached_parent_right(
                            parent, right_shift
                        )
                finally:
                    # Make sure the new parent is always
                    # restored on the way out in case of errors.
                    opts.set_raw_field_value(self, opts.parent_attr, parent_id)

                # If there were no exceptions raised then send a moved signal
                node_moved.send(
                    sender=self.__class__,
                    instance=self,
                    target=getattr(self, opts.parent_attr),
                )
            else:
                opts.set_raw_field_value(self, opts.parent_attr, parent_id)
                if not track_updates:
                    # When not using delayed/disabled updates,
                    # populate update_fields with user defined model fields.
                    # This helps preserve tree integrity when saving model on top
                    # of a modified tree.
                    # (args[3] is the positional ``update_fields`` of Model.save.)
                    if len(args) > 3:
                        if not args[3]:
                            args = list(args)
                            args[3] = self._get_user_field_names()
                            args = tuple(args)
                    else:
                        if not kwargs.get("update_fields", None):
                            kwargs["update_fields"] = self._get_user_field_names()

        else:
            # new node, do an insert
            if getattr(self, opts.left_attr) and getattr(self, opts.right_attr):
                # This node has already been set up for insertion.
                pass
            else:
                parent = getattr(self, opts.parent_attr)

                right_sibling = None
                # if we're inside delay_mptt_updates, don't do queries to find
                # sibling position.  instead, do default insertion. correct
                # positions will be found during partial rebuild later.
                # *unless* this is a root node. (as update tracking doesn't
                # handle re-ordering of trees.)
                if do_updates or parent is None:
                    if opts.order_insertion_by:
                        right_sibling = opts.get_ordered_insertion_target(self, parent)

                if right_sibling:
                    self.insert_at(
                        right_sibling,
                        "left",
                        allow_existing_pk=True,
                        refresh_target=False,
                    )

                    if parent:
                        # since we didn't insert into parent, we have to update parent.rght
                        # here instead of in TreeManager.insert_node()
                        right_shift = 2 * (self.get_descendant_count() + 1)
                        self._tree_manager._post_insert_update_cached_parent_right(
                            parent, right_shift
                        )
                else:
                    # Default insertion
                    self.insert_at(
                        parent, position="last-child", allow_existing_pk=True
                    )
        try:
            super().save(*args, **kwargs)
        finally:
            if collapse_old_tree is not None:
                self._tree_manager._create_tree_space(collapse_old_tree, -1)

        self._mptt_saved = True
        opts.update_mptt_cached_fields(self)

    save.alters_data = True

    def delete(self, *args, **kwargs):
        """Calling ``delete`` on a node will delete it as well as its full
        subtree, as opposed to reattaching all the subnodes to its parent node.

        There are no argument specific to a MPTT model, all the arguments will
        be passed directly to the django's ``Model.delete``.

        Returns whatever ``Model.delete`` returns."""
        try:
            # We have to make sure we use database's mptt values, since they
            # could have changed between the moment the instance was retrieved and
            # the moment it is deleted.
            # This happens for example if you delete several nodes at once from a queryset.
            fields_to_refresh = [
                self._mptt_meta.right_attr,
                self._mptt_meta.left_attr,
                self._mptt_meta.tree_id_attr,
            ]
            self.refresh_from_db(fields=fields_to_refresh)
        except self.__class__.DoesNotExist:
            # In case the object was already deleted, we don't want to throw an exception
            pass
        # Width of this node's subtree in lft/rght units: 2 * (descendants + 1).
        tree_width = self._mpttfield("right") - self._mpttfield("left") + 1
        target_right = self._mpttfield("right")
        tree_id = self._mpttfield("tree_id")
        self._tree_manager._close_gap(tree_width, target_right, tree_id)
        parent = cached_field_value(self, self._mptt_meta.parent_attr)
        if parent:
            # The cached parent's rght shrinks by the whole removed subtree,
            # mirroring the +2 * (descendants + 1) shift applied on insert.
            right_shift = -tree_width
            self._tree_manager._post_insert_update_cached_parent_right(
                parent, right_shift
            )

        return super().delete(*args, **kwargs)

    delete.alters_data = True

    def _mptt_refresh(self):
        """
        Reload the left/right/level/tree_id values of this instance from the
        database it was loaded from. Does nothing when the pk is falsy.
        """
        if not self.pk:
            return
        manager = type(self)._tree_manager
        opts = self._mptt_meta
        values = (
            manager.using(self._state.db)
            .filter(pk=self.pk)
            .values(
                opts.left_attr,
                opts.right_attr,
                opts.level_attr,
                opts.tree_id_attr,
            )[0]
        )
        for k, v in values.items():
            setattr(self, k, v)