1# orm/collections.py
2# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Support for collections of mapped entities.
9
10The collections package supplies the machinery used to inform the ORM of
11collection membership changes.  An instrumentation via decoration approach is
12used, allowing arbitrary types (including built-ins) to be used as entity
13collections without requiring inheritance from a base class.
14
15Instrumentation decoration relays membership change events to the
16:class:`.CollectionAttributeImpl` that is currently managing the collection.
17The decorators observe function call arguments and return values, tracking
18entities entering or leaving the collection.  Two decorator approaches are
19provided.  One is a bundle of generic decorators that map function arguments
20and return values to events::
21
22  from sqlalchemy.orm.collections import collection
23  class MyClass(object):
24      # ...
25
26      @collection.adds(1)
27      def store(self, item):
28          self.data.append(item)
29
30      @collection.removes_return()
31      def pop(self):
32          return self.data.pop()
33
34
35The second approach is a bundle of targeted decorators that wrap appropriate
36append and remove notifiers around the mutation methods present in the
37standard Python ``list``, ``set`` and ``dict`` interfaces.  These could be
38specified in terms of generic decorator recipes, but are instead hand-tooled
39for increased efficiency.  The targeted decorators occasionally implement
40adapter-like behavior, such as mapping bulk-set methods (``extend``,
41``update``, ``__setslice__``, etc.) into the series of atomic mutation events
42that the ORM requires.
43
44The targeted decorators are used internally for automatic instrumentation of
45entity collection classes.  Every collection class goes through a
46transformation process roughly like so:
47
481. If the class is a built-in, substitute a trivial sub-class
492. Is this class already instrumented?
503. Add in generic decorators
514. Sniff out the collection interface through duck-typing
525. Add targeted decoration to any undecorated interface method
53
54This process modifies the class at runtime, decorating methods and adding some
55bookkeeping properties.  This isn't possible (or desirable) for built-in
56classes like ``list``, so trivial sub-classes are substituted to hold
57decoration::
58
59  class InstrumentedList(list):
60      pass
61
62Collection classes can be specified in ``relationship(collection_class=)`` as
63types or a function that returns an instance.  Collection classes are
64inspected and instrumented during the mapper compilation phase.  The
65collection_class callable will be executed once to produce a specimen
66instance, and the type of that specimen will be instrumented.  Functions that
return built-in types like ``list`` will be adapted to produce instrumented
68instances.
69
When extending a known type like ``list``, additional decorations are
generally not needed.  Odds are, the extension method will delegate to a
72method that's already instrumented.  For example::
73
74  class QueueIsh(list):
75     def push(self, item):
76         self.append(item)
77     def shift(self):
78         return self.pop(0)
79
80There's no need to decorate these methods.  ``append`` and ``pop`` are already
81instrumented as part of the ``list`` interface.  Decorating them would fire
82duplicate events, which should be avoided.
83
84The targeted decoration tries not to rely on other methods in the underlying
85collection class, but some are unavoidable.  Many depend on 'read' methods
86being present to properly instrument a 'write', for example, ``__setitem__``
needs ``__getitem__``.  "Bulk" methods like ``update`` and ``extend`` may also
be reimplemented in terms of atomic appends and removes, so the ``extend``
89decoration will actually perform many ``append`` operations and not call the
90underlying method at all.
91
92Tight control over bulk operation and the firing of events is also possible by
93implementing the instrumentation internally in your methods.  The basic
94instrumentation package works under the general assumption that collection
95mutation will not raise unusual exceptions.  If you want to closely
96orchestrate append and remove events with exception management, internal
97instrumentation may be the answer.  Within your method,
98``collection_adapter(self)`` will retrieve an object that you can use for
99explicit control over triggering append and remove events.
100
101The owning object and :class:`.CollectionAttributeImpl` are also reachable
102through the adapter, allowing for some very sophisticated behavior.
103
104"""
105
106import operator
107import weakref
108
109from sqlalchemy.util.compat import inspect_getargspec
110from . import base
111from .. import exc as sa_exc
112from .. import util
113from ..sql import expression
114
# Public API exported by ``from sqlalchemy.orm.collections import *``.
__all__ = [
    "collection", "collection_adapter", "mapped_collection",
    "column_mapped_collection", "attribute_mapped_collection",
]

# Held by prepare_instrumentation() around its "is this class instrumented
# yet?" check and the instrumentation step that follows it.
__instrumentation_mutex = util.threading.Lock()
124
125
126class _PlainColumnGetter(object):
127    """Plain column getter, stores collection of Column objects
128    directly.
129
130    Serializes to a :class:`._SerializableColumnGetterV2`
131    which has more expensive __call__() performance
132    and some rare caveats.
133
134    """
135
136    def __init__(self, cols):
137        self.cols = cols
138        self.composite = len(cols) > 1
139
140    def __reduce__(self):
141        return _SerializableColumnGetterV2._reduce_from_cols(self.cols)
142
143    def _cols(self, mapper):
144        return self.cols
145
146    def __call__(self, value):
147        state = base.instance_state(value)
148        m = base._state_mapper(state)
149
150        key = [
151            m._get_state_attr_by_column(state, state.dict, col)
152            for col in self._cols(m)
153        ]
154
155        if self.composite:
156            return tuple(key)
157        else:
158            return key[0]
159
160
161class _SerializableColumnGetter(object):
162    """Column-based getter used in version 0.7.6 only.
163
164    Remains here for pickle compatibility with 0.7.6.
165
166    """
167
168    def __init__(self, colkeys):
169        self.colkeys = colkeys
170        self.composite = len(colkeys) > 1
171
172    def __reduce__(self):
173        return _SerializableColumnGetter, (self.colkeys,)
174
175    def __call__(self, value):
176        state = base.instance_state(value)
177        m = base._state_mapper(state)
178        key = [
179            m._get_state_attr_by_column(
180                state, state.dict, m.mapped_table.columns[k]
181            )
182            for k in self.colkeys
183        ]
184        if self.composite:
185            return tuple(key)
186        else:
187            return key[0]
188
189
class _SerializableColumnGetterV2(_PlainColumnGetter):
    """Picklable column getter able to handle multi-table mapped classes.

    Columns are stored as ``(column key, table key)`` pairs and resolved
    against the mapper's ``local_table`` metadata at call time.

    Two very unusual configurations are unsupported and may fail here:
    mappings whose tables span several MetaData objects, and mappings to
    non-Table selectables linked across inheriting mappers.

    """

    def __init__(self, colkeys):
        # sequence of (column key, table key or None) pairs
        self.colkeys = colkeys
        self.composite = len(colkeys) > 1

    def __reduce__(self):
        return self.__class__, (self.colkeys,)

    @classmethod
    def _reduce_from_cols(cls, cols):
        def _table_key(column):
            # only real tables have a key usable for MetaData lookup
            if isinstance(column.table, expression.TableClause):
                return column.table.key
            return None

        pairs = [(column.key, _table_key(column)) for column in cols]
        return _SerializableColumnGetterV2, (pairs,)

    def _cols(self, mapper):
        metadata = getattr(mapper.local_table, "metadata", None)

        def _resolve(column_key, table_key):
            found_in_metadata = (
                table_key is not None
                and metadata is not None
                and table_key in metadata
            )
            if found_in_metadata:
                return metadata.tables[table_key].c[column_key]
            # fall back to the mapper's own table
            return mapper.local_table.c[column_key]

        return [_resolve(ckey, tkey) for ckey, tkey in self.colkeys]
229
230
def column_mapped_collection(mapping_spec):
    """Build a factory for dict-like collections keyed by column values.

    The returned zero-argument callable produces a
    :class:`.MappedCollection` whose keying function reads the column(s)
    named by ``mapping_spec`` (a single Column or a sequence of Columns)
    from each member entity.

    Keys must not change for as long as the object is in the collection.
    For instance, a foreign key value that goes from None to a
    database-assigned integer after a session flush cannot be used as
    the key.

    """
    specs = util.to_list(mapping_spec)
    keyfunc = _PlainColumnGetter(
        [expression._only_column_elements(spec, "mapping_spec") for spec in specs]
    )
    return lambda: MappedCollection(keyfunc)
250
251
252class _SerializableAttrGetter(object):
253    def __init__(self, name):
254        self.name = name
255        self.getter = operator.attrgetter(name)
256
257    def __call__(self, target):
258        return self.getter(target)
259
260    def __reduce__(self):
261        return _SerializableAttrGetter, (self.name,)
262
263
def attribute_mapped_collection(attr_name):
    """Build a factory for dict-like collections keyed by an attribute.

    The returned zero-argument callable produces a
    :class:`.MappedCollection` keyed on the ``attr_name`` attribute (given
    as a string) of each member entity.

    Keys must not change for as long as the object is in the collection.
    For instance, a foreign key value that goes from None to a
    database-assigned integer after a session flush cannot be used as
    the key.

    """
    keyfunc = _SerializableAttrGetter(attr_name)
    return lambda: MappedCollection(keyfunc)
279
280
def mapped_collection(keyfunc):
    """Build a factory for dict-like collections with a custom key.

    ``keyfunc`` is any callable that accepts an entity and returns its key.
    The returned zero-argument callable produces a
    :class:`.MappedCollection` using that function.

    Keys must not change for as long as the object is in the collection.
    For instance, a foreign key value that goes from None to a
    database-assigned integer after a session flush cannot be used as
    the key.

    """
    return lambda: MappedCollection(keyfunc)
295
296
class collection(object):
    """Decorators for entity collection classes.

    The decorators fall into two groups: annotations and interception recipes.

    The annotating decorators (appender, remover, iterator, linker, converter,
    internally_instrumented) indicate the method's purpose and take no
    arguments.  They are not written with parens::

        @collection.appender
        def append(self, append): ...

    The recipe decorators all require parens, even those that take no
    arguments::

        @collection.adds('entity')
        def insert(self, position, entity): ...

        @collection.removes_return()
        def popitem(self): ...

    Every decorator here only sets ``_sa_*`` marker attributes on the
    function and returns it unchanged; the markers are read later when
    the collection class is instrumented.

    """

    # Bundled as a class solely for ease of use: packaging, doc strings,
    # importability.

    @staticmethod
    def appender(fn):
        """Tag the method as the collection appender.

        The appender method is called with one positional argument: the value
        to append. The method will be automatically decorated with 'adds(1)'
        if not already decorated::

            @collection.appender
            def add(self, append): ...

            # or, equivalently
            @collection.appender
            @collection.adds(1)
            def add(self, append): ...

            # for mapping type, an 'append' may kick out a previous value
            # that occupies that slot.  consider d['a'] = 'foo'- any previous
            # value in d['a'] is discarded.
            @collection.appender
            @collection.replaces(1)
            def add(self, entity):
                key = some_key_func(entity)
                previous = None
                if key in self:
                    previous = self[key]
                self[key] = entity
                return previous

        If the value to append is not allowed in the collection, you may
        raise an exception.  Something to remember is that the appender
        will be called for each object mapped by a database query.  If the
        database contains rows that violate your collection semantics, you
        will need to get creative to fix the problem, as access via the
        collection will not work.

        If the appender method is internally instrumented, you must also
        receive the keyword argument '_sa_initiator' and ensure its
        promulgation to collection events.

        """
        fn._sa_instrument_role = "appender"
        return fn

    @staticmethod
    def remover(fn):
        """Tag the method as the collection remover.

        The remover method is called with one positional argument: the value
        to remove. The method will be automatically decorated with
        :meth:`removes_return` if not already decorated::

            @collection.remover
            def zap(self, entity): ...

            # or, equivalently
            @collection.remover
            @collection.removes_return()
            def zap(self, entity): ...

        If the value to remove is not present in the collection, you may
        raise an exception or return None to ignore the error.

        If the remove method is internally instrumented, you must also
        receive the keyword argument '_sa_initiator' and ensure its
        promulgation to collection events.

        """
        fn._sa_instrument_role = "remover"
        return fn

    @staticmethod
    def iterator(fn):
        """Tag the method as the collection iterator.

        The iterator method is called with no arguments.  It is expected to
        return an iterator over all collection members::

            @collection.iterator
            def __iter__(self): ...

        """
        fn._sa_instrument_role = "iterator"
        return fn

    @staticmethod
    def internally_instrumented(fn):
        """Tag the method as instrumented.

        This tag will prevent any decoration from being applied to the
        method. Use this if you are orchestrating your own calls to
        :func:`.collection_adapter` in one of the basic SQLAlchemy
        interface methods, or to prevent an automatic ABC method
        decoration from wrapping your implementation::

            # normally an 'extend' method on a list-like class would be
            # automatically intercepted and re-implemented in terms of
            # SQLAlchemy events and append().  your implementation will
            # never be called, unless:
            @collection.internally_instrumented
            def extend(self, items): ...

        """
        fn._sa_instrumented = True
        return fn

    @staticmethod
    @util.deprecated(
        "1.0",
        "The :meth:`.collection.linker` handler is deprecated and will "
        "be removed in a future release.  Please refer to the "
        ":meth:`.AttributeEvents.init_collection` "
        "and :meth:`.AttributeEvents.dispose_collection` event handlers. "
    )
    def linker(fn):
        """Tag the method as a "linked to attribute" event handler.

        This optional event handler will be called when the collection class
        is linked to or unlinked from the InstrumentedAttribute.  It is
        invoked immediately after the '_sa_adapter' property is set on
        the instance.  A single argument is passed: the collection adapter
        that has been linked, or None if unlinking.

        """
        fn._sa_instrument_role = "linker"
        return fn

    link = linker
    """Synonym for :meth:`.collection.linker`.

    .. deprecated:: 1.0 - :meth:`.collection.link` is deprecated and will be
       removed in a future release.

    """

    @staticmethod
    def converter(fn):
        """Tag the method as the collection converter.

        This optional method will be called when a collection is being
        replaced entirely, as in::

            myobj.acollection = [newvalue1, newvalue2]

        The converter method will receive the object being assigned and should
        return an iterable of values suitable for use by the ``appender``
        method.  A converter must not assign values or mutate the collection,
        its sole job is to adapt the value the user provides into an iterable
        of values for the ORM's use.

        The default converter implementation will use duck-typing to do the
        conversion.  A dict-like collection will be converted into an
        iterable of dictionary values, and other types will simply be
        iterated::

            @collection.converter
            def convert(self, other): ...

        If the duck-typing of the object does not match the type of this
        collection, a TypeError is raised.

        Supply an implementation of this method if you want to expand the
        range of possible types that can be assigned in bulk or perform
        validation on the values about to be assigned.

        """
        fn._sa_instrument_role = "converter"
        return fn

    @staticmethod
    def adds(arg):
        """Mark the method as adding an entity to the collection.

        Adds "add to collection" handling to the method.  The decorator
        argument indicates which method argument holds the SQLAlchemy-relevant
        value.  Arguments can be specified positionally (i.e. integer) or by
        name::

            @collection.adds(1)
            def push(self, item): ...

            @collection.adds('entity')
            def do_stuff(self, thing, entity=None): ...

        """

        def decorator(fn):
            fn._sa_instrument_before = ("fire_append_event", arg)
            return fn

        return decorator

    @staticmethod
    def replaces(arg):
        """Mark the method as replacing an entity in the collection.

        Adds "add to collection" and "remove from collection" handling to
        the method.  The decorator argument indicates which method argument
        holds the SQLAlchemy-relevant value to be added, and return value, if
        any will be considered the value to remove.

        Arguments can be specified positionally (i.e. integer) or by name::

            @collection.replaces(2)
            def __setitem__(self, index, item): ...

        """

        def decorator(fn):
            fn._sa_instrument_before = ("fire_append_event", arg)
            fn._sa_instrument_after = "fire_remove_event"
            return fn

        return decorator

    @staticmethod
    def removes(arg):
        """Mark the method as removing an entity in the collection.

        Adds "remove from collection" handling to the method.  The decorator
        argument indicates which method argument holds the SQLAlchemy-relevant
        value to be removed. Arguments can be specified positionally (i.e.
        integer) or by name::

            @collection.removes(1)
            def zap(self, item): ...

        For methods where the value to remove is not known at call-time, use
        collection.removes_return.

        """

        def decorator(fn):
            fn._sa_instrument_before = ("fire_remove_event", arg)
            return fn

        return decorator

    @staticmethod
    def removes_return():
        """Mark the method as removing an entity in the collection.

        Adds "remove from collection" handling to the method.  The return
        value of the method, if any, is considered the value to remove.  The
        method arguments are not inspected::

            @collection.removes_return()
            def pop(self): ...

        For methods where the value to remove is known at call-time, use
        collection.removes.

        """

        def decorator(fn):
            fn._sa_instrument_after = "fire_remove_event"
            return fn

        return decorator
582
583
# Returns the ``_sa_adapter`` attribute of a collection instance, which is
# where CollectionAdapter.__init__ and CollectionAdapter.__setstate__ attach
# the adapter.  Being a plain attrgetter, it raises AttributeError for a
# collection that has no adapter attached.
collection_adapter = operator.attrgetter("_sa_adapter")
"""Fetch the :class:`.CollectionAdapter` for a collection."""
586
587
class CollectionAdapter(object):
    """Mediates every ORM interaction with a user-level collection.

    Base-level operations (append, remove, iterate) are forwarded to the
    wrapped collection's ``_sa_appender`` / ``_sa_remover`` /
    ``_sa_iterator`` hooks.  Add/remove events for entities entering or
    leaving the collection are relayed to the attribute implementation
    held in ``self.attr``.

    The collection is held only through a weak reference.  The adapter
    installs itself on the collection as ``_sa_adapter``.

    """

    __slots__ = (
        "attr",
        "_key",
        "_data",
        "owner_state",
        "_converter",
        "invalidated",
    )

    def __init__(self, attr, owner_state, data):
        self.attr = attr
        self._key = attr.key
        self._data = weakref.ref(data)
        self.owner_state = owner_state
        # link the collection back to this adapter, then cache its converter
        data._sa_adapter = self
        self._converter = data._sa_converter
        self.invalidated = False

    def _warn_invalidated(self):
        util.warn("This collection has been invalidated.")

    @property
    def data(self):
        "The entity collection being adapted."
        return self._data()

    @property
    def _referenced_by_owner(self):
        """Whether the owner's state dict still holds this collection.

        During a bulk replace, this is False for the collection that is
        being replaced.

        """
        current = self.owner_state.dict[self._key]
        return current is self._data()

    def bulk_appender(self):
        target = self._data()
        return target._sa_appender

    def append_with_event(self, item, initiator=None):
        """Append ``item`` via the collection's appender, firing events."""
        append = self._data()._sa_appender
        append(item, _sa_initiator=initiator)

    def append_without_event(self, item):
        """Append or restore ``item`` silently (initiator ``False``)."""
        append = self._data()._sa_appender
        append(item, _sa_initiator=False)

    def append_multiple_without_event(self, items):
        """Append or restore each of ``items`` silently."""
        append = self._data()._sa_appender
        for member in items:
            append(member, _sa_initiator=False)

    def bulk_remover(self):
        target = self._data()
        return target._sa_remover

    def remove_with_event(self, item, initiator=None):
        """Remove ``item`` via the collection's remover, firing events."""
        remove = self._data()._sa_remover
        remove(item, _sa_initiator=initiator)

    def remove_without_event(self, item):
        """Remove ``item`` silently (initiator ``False``)."""
        remove = self._data()._sa_remover
        remove(item, _sa_initiator=False)

    def clear_with_event(self, initiator=None):
        """Remove every member, firing a remove event for each one."""
        remove = self._data()._sa_remover
        # snapshot first: the removals mutate the collection being walked
        for member in list(self):
            remove(member, _sa_initiator=initiator)

    def clear_without_event(self):
        """Remove every member without firing any events."""
        remove = self._data()._sa_remover
        for member in list(self):
            remove(member, _sa_initiator=False)

    def __iter__(self):
        """Iterate over the members of the collection."""
        members = self._data()._sa_iterator()
        return iter(members)

    def __len__(self):
        """Count the members yielded by the collection's iterator."""
        members = list(self._data()._sa_iterator())
        return len(members)

    def __bool__(self):
        # always truthy, even when the collection is empty
        return True

    __nonzero__ = __bool__

    def fire_append_event(self, item, initiator=None):
        """Announce that an entity has entered the collection.

        ``initiator`` is a token owned by the InstrumentedAttribute that
        started the mutation.  Leave it as None unless passing along an
        initiator from a chained operation.  ``False`` suppresses the event,
        and ``item`` is returned unchanged.

        Otherwise the return value of ``attr.fire_append_event`` is
        returned.

        """
        if initiator is False:
            return item
        if self.invalidated:
            self._warn_invalidated()
        return self.attr.fire_append_event(
            self.owner_state, self.owner_state.dict, item, initiator
        )

    def fire_remove_event(self, item, initiator=None):
        """Announce that an entity has left the collection.

        ``initiator`` identifies the InstrumentedAttribute that started the
        mutation.  Leave it as None unless passing along an initiator from a
        chained operation.  ``False`` suppresses the event.

        """
        if initiator is False:
            return
        if self.invalidated:
            self._warn_invalidated()
        self.attr.fire_remove_event(
            self.owner_state, self.owner_state.dict, item, initiator
        )

    def fire_pre_remove_event(self, initiator=None):
        """Announce that an entity is about to be removed.

        Used only when the entity cannot be removed after calling
        fire_remove_event().

        """
        if self.invalidated:
            self._warn_invalidated()
        self.attr.fire_pre_remove_event(
            self.owner_state, self.owner_state.dict, initiator=initiator
        )

    def __getstate__(self):
        state = self.owner_state
        return {
            "key": self._key,
            "owner_state": state,
            "owner_cls": state.class_,
            "data": self.data,
            "invalidated": self.invalidated,
        }

    def __setstate__(self, d):
        self._key = d["key"]
        self.owner_state = d["owner_state"]
        collection_obj = d["data"]
        self._data = weakref.ref(collection_obj)
        self._converter = collection_obj._sa_converter
        collection_obj._sa_adapter = self
        self.invalidated = d["invalidated"]
        # re-resolve the attribute implementation from the owner class
        self.attr = getattr(d["owner_cls"], self._key).impl
757
758
def bulk_replace(values, existing_adapter, new_adapter, initiator=None):
    """Fill ``new_adapter`` from ``values``, diffing against the old members.

    Every member of ``values`` is appended to ``new_adapter`` in order.
    Members that were not in ``existing_adapter`` are appended with
    ``initiator``, so they fire append events.  Members that were already
    present are appended silently (initiator ``False``).  Finally, members
    of ``existing_adapter`` that are absent from ``values`` get remove
    events fired through ``existing_adapter``.  Membership is compared by
    identity, not equality.

    :param values: a list of collection member instances
    :param existing_adapter: the :class:`.CollectionAdapter` whose
     instances are being replaced, or None
    :param new_adapter: an empty :class:`.CollectionAdapter` to load
     with ``values``
    :param initiator: event initiator passed along to fired events

    """

    assert isinstance(values, list)

    IdentitySet = util.IdentitySet
    previous = IdentitySet(existing_adapter or ())
    retained = previous.intersection(values or ())
    incoming = IdentitySet(values or ()).difference(retained)
    dropped = previous.difference(retained)

    append = new_adapter.bulk_appender()

    for member in values or ():
        if member in incoming:
            append(member, _sa_initiator=initiator)
        elif member in retained:
            # already a member: restore without firing an event
            append(member, _sa_initiator=False)

    if existing_adapter:
        for member in dropped:
            existing_adapter.fire_remove_event(member, initiator=initiator)
797
798
def prepare_instrumentation(factory):
    """Prepare a callable for future use as a collection class factory.

    Given a collection class factory (either a type or no-arg callable),
    return another factory that will produce compatible instances when
    called.

    This function is responsible for converting collection_class=list
    into the run-time behavior of collection_class=InstrumentedList.

    :param factory: a collection type or a zero-argument callable that
      returns a collection instance.
    :returns: a factory whose instances are of an instrumented class.

    """
    # Convert a builtin to 'Instrumented*'
    if factory in __canned_instrumentation:
        factory = __canned_instrumentation[factory]

    # Create a specimen
    cls = type(factory())

    # Did factory callable return a builtin?
    if cls in __canned_instrumentation:
        # Wrap it so that it returns our 'Instrumented*'
        factory = __converting_factory(cls, factory)
        # Instrumentation below operates on the *class* (it walks
        # cls.__mro__ and compares against id(cls)), so take the type of
        # the wrapped factory's specimen rather than the instance itself.
        cls = type(factory())

    # Instrument the class if needed.  The lock keeps the "already
    # instrumented?" check and the instrumentation step together.
    with __instrumentation_mutex:
        if getattr(cls, "_sa_instrumented", None) != id(cls):
            _instrument_class(cls)

    return factory
832
833
def __converting_factory(specimen_cls, original_factory):
    """Return a wrapper that converts a "canned" collection like
    set, dict, list into the Instrumented* version.

    :param specimen_cls: the builtin type produced by ``original_factory``;
      must be a key of ``__canned_instrumentation``.
    :param original_factory: zero-argument callable returning an instance
      of ``specimen_cls``.
    :returns: a zero-argument callable that builds the original collection
      and passes it to the matching Instrumented* constructor.

    """

    instrumented_cls = __canned_instrumentation[specimen_cls]

    def wrapper():
        raw = original_factory()
        return instrumented_cls(raw)

    # often flawed but better than nothing.  Not every callable defines
    # __name__ (functools.partial objects, for example, do not), so fall
    # back to the callable's type name instead of raising AttributeError.
    factory_name = getattr(
        original_factory, "__name__", type(original_factory).__name__
    )
    wrapper.__name__ = "%sWrapper" % factory_name
    wrapper.__doc__ = getattr(original_factory, "__doc__", None)

    return wrapper
851
852
def _instrument_class(cls):
    """Modify methods in a class and install instrumentation.

    :param cls: the collection class to instrument in place.
    :raises sqlalchemy.exc.ArgumentError: if ``cls`` is a built-in type,
      which cannot be modified at runtime.

    """

    # In the normal call flow, a request for any of the 3 basic collection
    # types is transformed into one of our trivial subclasses
    # (e.g. InstrumentedList).  Catch anything else that sneaks in here...
    # Built-in types report "__builtin__" as their module on Python 2 and
    # "builtins" on Python 3; check both so the guard works on either.
    if cls.__module__ in ("__builtin__", "builtins"):
        raise sa_exc.ArgumentError(
            "Can not instrument a built-in type. Use a "
            "subclass, even a trivial one."
        )

    # collect decorator-declared roles and per-method event requests
    roles, methods = _locate_roles_and_methods(cls)

    # fill in roles/decorators for known dict/set/list-like interfaces
    _setup_canned_roles(cls, roles, methods)

    # presumably validates that the required roles are present; defined
    # outside this view
    _assert_required_roles(cls, roles, methods)

    # defined outside this view; presumably installs the collected roles
    # and methods onto cls
    _set_collection_attributes(cls, roles, methods)
872
873
def _locate_roles_and_methods(cls):
    """Scan ``cls.__mro__`` for methods carrying collection decorator marks.

    Returns ``(roles, methods)``:

    * ``roles`` maps a role name ("appender", "remover", ...) to the name of
      the first method in MRO order tagged with that role.
    * ``methods`` maps a method name to an ``(op, argument, after_op)``
      tuple built from its ``_sa_instrument_before`` /
      ``_sa_instrument_after`` marks; missing parts are None.

    """
    valid_roles = ("appender", "remover", "iterator", "linker", "converter")
    valid_ops = ("fire_append_event", "fire_remove_event")

    roles = {}
    methods = {}

    for klass in cls.__mro__:
        for attr_name, candidate in vars(klass).items():
            if not util.callable(candidate):
                continue

            # the most-derived declaration of a role wins
            if hasattr(candidate, "_sa_instrument_role"):
                declared_role = candidate._sa_instrument_role
                assert declared_role in valid_roles
                roles.setdefault(declared_role, attr_name)

            before = after = None
            if hasattr(candidate, "_sa_instrument_before"):
                before_op, before_arg = candidate._sa_instrument_before
                assert before_op in valid_ops
                before = (before_op, before_arg)
            if hasattr(candidate, "_sa_instrument_after"):
                after = candidate._sa_instrument_after
                assert after in valid_ops

            # NOTE(review): unlike roles, entries here are overwritten as the
            # MRO is walked toward the base classes; confirm this is intended.
            if before:
                methods[attr_name] = before + (after,)
            elif after:
                methods[attr_name] = (None, None, after)

    return roles, methods
916
917
def _setup_canned_roles(cls, roles, methods):
    """Fill in roles and wrappers inherited from a known collection type.

    If ``cls`` duck-types as one of the known collection types (list,
    set, dict), any role it did not declare itself is taken from that
    type's canned role map.  Each canned decorator is also applied to
    the matching method of ``cls``, unless the method is absent, already
    scheduled in ``methods``, or already marked ``_sa_instrumented``.

    """
    collection_type = util.duck_type_collection(cls)
    if collection_type not in __interfaces:
        return

    canned_roles, decorators = __interfaces[collection_type]
    for role, name in canned_roles.items():
        roles.setdefault(role, name)

    # apply ABC auto-decoration to methods that need it
    for method_name, decorator in decorators.items():
        existing = getattr(cls, method_name, None)
        if (
            not existing
            or method_name in methods
            or hasattr(existing, "_sa_instrumented")
        ):
            continue
        setattr(cls, method_name, decorator(existing))
940
941
def _assert_required_roles(cls, roles, methods):
    """Verify the mandatory roles and add implicit instrumentation.

    An appender, a remover and an iterator must each be declared in
    ``roles`` and exist on ``cls``; otherwise ArgumentError is raised
    (checked in that order).  An appender or remover that is neither
    scheduled in ``methods`` nor already ``_sa_instrumented`` is
    scheduled to fire its event on positional argument 1.

    """
    # (role, error message, event fired implicitly on argument 1 or None)
    requirements = (
        (
            "appender",
            "Type %s must elect an appender method to be "
            "a collection class",
            "fire_append_event",
        ),
        (
            "remover",
            "Type %s must elect a remover method to be "
            "a collection class",
            "fire_remove_event",
        ),
        (
            "iterator",
            "Type %s must elect an iterator method to be "
            "a collection class",
            None,
        ),
    )

    for role, message, implicit_event in requirements:
        if role not in roles or not hasattr(cls, roles[role]):
            raise sa_exc.ArgumentError(message % cls.__name__)
        if implicit_event is None:
            continue
        method_name = roles[role]
        already_handled = method_name in methods or hasattr(
            getattr(cls, method_name), "_sa_instrumented"
        )
        if not already_handled:
            methods[method_name] = (implicit_event, 1, None)
972
973
def _set_collection_attributes(cls, roles, methods):
    """Install the computed instrumentation onto ``cls``.

    Every method listed in ``methods`` is replaced by an
    ``_instrument_membership_mutator`` wrapper; each role is then
    recorded as a ``_sa_<role>`` class attribute pointing at the
    (possibly just wrapped) method; finally the class-level bookkeeping
    attributes are stamped.

    """
    for method_name, spec in methods.items():
        before, argument, after = spec
        original = getattr(cls, method_name)
        wrapped = _instrument_membership_mutator(
            original, before, argument, after
        )
        setattr(cls, method_name, wrapped)

    # intern the role map (_sa_appender, _sa_remover, _sa_iterator, ...)
    for role in roles:
        setattr(cls, "_sa_%s" % role, getattr(cls, roles[role]))

    cls._sa_adapter = None

    if not hasattr(cls, "_sa_converter"):
        cls._sa_converter = None
    # Store id(cls) rather than a flag: a subclass inheriting this value
    # will not match its own id() and so gets instrumented separately.
    cls._sa_instrumented = id(cls)
996
997
def _instrument_membership_mutator(method, before, argument, after):
    """Route method args and/or return value through the collection
    adapter.

    :param method: the method to wrap.
    :param before: name of the adapter method (``"fire_append_event"`` or
      ``"fire_remove_event"``) to call with one of ``method``'s arguments
      before ``method`` runs, or a false value for none.
    :param argument: when ``before`` is set, the argument handed to that
      event: either a positional index (``self`` is position 0) or a
      parameter name.
    :param after: name of the adapter method to call with ``method``'s
      return value (only when that value is not None), or a false value.
    :return: a wrapper marked ``_sa_instrumented``.  It pops an optional
      ``_sa_initiator`` keyword; passing ``_sa_initiator=False``
      suppresses all events for that call.
    :raises sa_exc.ArgumentError: (from the wrapper) when ``before`` is set
      and the selected argument was not supplied.

    """
    # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))'
    if before:
        fn_args = list(util.flatten_iterator(inspect_getargspec(method)[0]))
        if isinstance(argument, int):
            pos_arg = argument
            # parameter name at that position, so the value can also be
            # found when the caller passes it by keyword
            named_arg = fn_args[argument] if len(fn_args) > argument else None
        else:
            if argument in fn_args:
                pos_arg = fn_args.index(argument)
            else:
                pos_arg = None
            named_arg = argument
        del fn_args

    def wrapper(*args, **kw):
        if before:
            if pos_arg is None:
                if named_arg not in kw:
                    raise sa_exc.ArgumentError(
                        "Missing argument %s" % argument
                    )
                value = kw[named_arg]
            else:
                if len(args) > pos_arg:
                    value = args[pos_arg]
                elif named_arg in kw:
                    value = kw[named_arg]
                else:
                    raise sa_exc.ArgumentError(
                        "Missing argument %s" % argument
                    )

        initiator = kw.pop("_sa_initiator", None)
        if initiator is False:
            executor = None
        else:
            # args[0] is the collection instance (self)
            executor = args[0]._sa_adapter

        if before and executor:
            getattr(executor, before)(value, initiator)

        if not after or not executor:
            return method(*args, **kw)
        else:
            res = method(*args, **kw)
            if res is not None:
                getattr(executor, after)(res, initiator)
            return res

    wrapper._sa_instrumented = True
    if hasattr(method, "_sa_instrument_role"):
        wrapper._sa_instrument_role = method._sa_instrument_role
    wrapper.__name__ = method.__name__
    wrapper.__doc__ = method.__doc__
    return wrapper
1056
1057
def __set(collection, item, _sa_initiator=None):
    """Fire the append event for ``item`` and return the item to store.

    With ``_sa_initiator=False``, or when the collection has no adapter,
    ``item`` is returned untouched; otherwise the adapter's
    ``fire_append_event`` result is returned instead.  May eventually be
    inlined into decorators.
    """
    if _sa_initiator is False:
        return item
    executor = collection._sa_adapter
    if not executor:
        return item
    return executor.fire_append_event(item, _sa_initiator)
1066
1067
def __del(collection, item, _sa_initiator=None):
    """Fire the remove event for ``item``, if events are enabled.

    Does nothing when ``_sa_initiator`` is False or the collection has no
    adapter.  May eventually be inlined into decorators.
    """
    if _sa_initiator is False:
        return
    executor = collection._sa_adapter
    if executor:
        executor.fire_remove_event(item, _sa_initiator)
1074
1075
def __before_delete(collection, _sa_initiator=None):
    """Fire the adapter's pre-remove event ('commit existing value').

    Unlike ``__set``/``__del`` there is no ``_sa_initiator is False``
    short-circuit here: the event fires whenever an adapter is present.
    """
    executor = collection._sa_adapter
    if not executor:
        return
    executor.fire_pre_remove_event(_sa_initiator)
1081
1082
def _list_decorators():
    """Tailored instrumentation wrappers for any list-like class.

    Returns a dict mapping list method names to decorator factories.  Each
    factory receives the class's original method ``fn`` and returns a
    replacement that fires the ORM append/remove events (through
    ``__set`` / ``__del`` / ``__before_delete``) around a call to ``fn``.
    NOTE: the result is built from ``locals()``, so no other local names
    may be defined at this level.
    """

    def _tidy(fn):
        fn._sa_instrumented = True
        fn.__doc__ = getattr(list, fn.__name__).__doc__

    def append(fn):
        def append(self, item, _sa_initiator=None):
            item = __set(self, item, _sa_initiator)
            fn(self, item)

        _tidy(append)
        return append

    def remove(fn):
        def remove(self, value, _sa_initiator=None):
            __before_delete(self, _sa_initiator)
            # testlib.pragma exempt:__eq__
            fn(self, value)
            __del(self, value, _sa_initiator)

        _tidy(remove)
        return remove

    def insert(fn):
        def insert(self, index, value):
            value = __set(self, value)
            fn(self, index, value)

        _tidy(insert)
        return insert

    def __setitem__(fn):
        def __setitem__(self, index, value):
            if not isinstance(index, slice):
                existing = self[index]
                if existing is not None:
                    __del(self, existing)
                value = __set(self, value)
                fn(self, index, value)
            else:
                # slice assignment requires __delitem__, insert, __len__.
                # Snapshot the incoming values first: ``value`` may be
                # ``self`` (``l[:] = l`` would otherwise empty it, and
                # ``l[1:] = l`` would grow it while iterating forever), or
                # a one-shot iterator whose length the extended-slice
                # check needs.
                value = list(value)
                # slice.indices() normalizes None, negative and
                # out-of-range bounds, and negative steps, exactly as
                # list itself does; it rejects a zero step.
                start, stop, step = index.indices(len(self))

                if step == 1:
                    for _ in range(start, stop):
                        if len(self) > start:
                            del self[start]

                    for i, item in enumerate(value):
                        self.insert(i + start, item)
                else:
                    rng = list(range(start, stop, step))
                    if len(value) != len(rng):
                        raise ValueError(
                            "attempt to assign sequence of size %s to "
                            "extended slice of size %s"
                            % (len(value), len(rng))
                        )
                    for i, item in zip(rng, value):
                        self.__setitem__(i, item)

        _tidy(__setitem__)
        return __setitem__

    def __delitem__(fn):
        def __delitem__(self, index):
            if not isinstance(index, slice):
                item = self[index]
                __del(self, item)
                fn(self, index)
            else:
                # slice deletion requires __getslice__ and a slice-groking
                # __getitem__ for stepped deletion
                # note: not breaking this into atomic dels
                for item in self[index]:
                    __del(self, item)
                fn(self, index)

        _tidy(__delitem__)
        return __delitem__

    if util.py2k:

        def __setslice__(fn):
            def __setslice__(self, start, end, values):
                for value in self[start:end]:
                    __del(self, value)
                values = [__set(self, value) for value in values]
                fn(self, start, end, values)

            _tidy(__setslice__)
            return __setslice__

        def __delslice__(fn):
            def __delslice__(self, start, end):
                for value in self[start:end]:
                    __del(self, value)
                fn(self, start, end)

            _tidy(__delslice__)
            return __delslice__

    def extend(fn):
        def extend(self, iterable):
            # ``l.extend(l)`` must not iterate the list it is growing,
            # which would never terminate
            if iterable is self:
                iterable = list(iterable)
            for value in iterable:
                self.append(value)

        _tidy(extend)
        return extend

    def __iadd__(fn):
        def __iadd__(self, iterable):
            # list.__iadd__ takes any iterable and seems to let TypeError
            # raise as-is instead of returning NotImplemented.
            # ``l += l`` must not iterate the list it is growing.
            if iterable is self:
                iterable = list(iterable)
            for value in iterable:
                self.append(value)
            return self

        _tidy(__iadd__)
        return __iadd__

    def pop(fn):
        def pop(self, index=-1):
            __before_delete(self)
            item = fn(self, index)
            __del(self, item)
            return item

        _tidy(pop)
        return pop

    if not util.py2k:

        def clear(fn):
            # NOTE(review): ``index`` is unused; kept only so the wrapper's
            # signature stays unchanged.
            def clear(self, index=-1):
                for item in self:
                    __del(self, item)
                fn(self)

            _tidy(clear)
            return clear

    # __imul__ : not wrapping this.  all members of the collection are already
    # present, so no need to fire appends... wrapping it with an explicit
    # decorator is still possible, so events on *= can be had if they're
    # desired.  hard to imagine a use case for __imul__, though.

    l = locals().copy()
    l.pop("_tidy")
    return l
1244
1245
def _dict_decorators():
    """Tailored instrumentation wrappers for any dict-like mapping class.

    Returns a dict mapping dict method names to decorator factories.  Each
    factory receives the class's original method ``fn`` and returns a
    replacement that fires the ORM append/remove events for the *values*
    being stored or discarded.  NOTE: the result is built from
    ``locals()``, so no other local names may be defined at this level.
    """

    def _tidy(fn):
        fn._sa_instrumented = True
        fn.__doc__ = getattr(dict, fn.__name__).__doc__

    # sentinel telling "argument omitted" apart from an explicit None
    Unspecified = util.symbol("Unspecified")

    def __setitem__(fn):
        def __setitem__(self, key, value, _sa_initiator=None):
            # replacing an existing entry removes the old value first
            if key in self:
                __del(self, self[key], _sa_initiator)
            value = __set(self, value, _sa_initiator)
            fn(self, key, value)

        _tidy(__setitem__)
        return __setitem__

    def __delitem__(fn):
        def __delitem__(self, key, _sa_initiator=None):
            if key in self:
                __del(self, self[key], _sa_initiator)
            fn(self, key)

        _tidy(__delitem__)
        return __delitem__

    def clear(fn):
        def clear(self):
            for key in self:
                __del(self, self[key])
            fn(self)

        _tidy(clear)
        return clear

    def pop(fn):
        def pop(self, key, default=Unspecified):
            if key in self:
                __del(self, self[key])
            if default is Unspecified:
                return fn(self, key)
            else:
                return fn(self, key, default)

        _tidy(pop)
        return pop

    def popitem(fn):
        def popitem(self):
            __before_delete(self)
            item = fn(self)
            __del(self, item[1])
            return item

        _tidy(popitem)
        return popitem

    def setdefault(fn):
        def setdefault(self, key, default=None):
            if key not in self:
                self.__setitem__(key, default)
                return default
            else:
                return self.__getitem__(key)

        _tidy(setdefault)
        return setdefault

    def update(fn):
        def update(self, __other=Unspecified, **kw):
            if __other is not Unspecified:
                if hasattr(__other, "keys"):
                    # Mirror dict.update(): take a mapping's keys from
                    # keys() rather than by iterating it, and snapshot
                    # them so ``d.update(d)`` cannot mutate what is being
                    # iterated.
                    for key in list(__other.keys()):
                        if key not in self or self[key] is not __other[key]:
                            self[key] = __other[key]
                else:
                    for key, value in __other:
                        if key not in self or self[key] is not value:
                            self[key] = value
            for key in kw:
                if key not in self or self[key] is not kw[key]:
                    self[key] = kw[key]

        _tidy(update)
        return update

    l = locals().copy()
    l.pop("_tidy")
    l.pop("Unspecified")
    return l
1338
1339
# Concrete set types always accepted as the other operand by the
# _set_binops_check_* helpers below (which also accept the instance's own
# class); the strict variant gates the in-place operators |=, -=, &=, ^=
# of instrumented sets.
_set_binop_bases = (set, frozenset)
1341
1342
def _set_binops_check_strict(self, obj):
    """Allow only set, frozenset and self.__class__-derived
    objects in binops.

    Returns True when ``obj`` is an instance of set, frozenset or the
    class of ``self``; anything merely set-like is refused.
    """
    allowed_types = _set_binop_bases + (self.__class__,)
    return isinstance(obj, allowed_types)
1347
1348
def _set_binops_check_loose(self, obj):
    """Allow anything set-like to participate in set binops.

    Accepts everything the strict check does, plus any object that
    ``util.duck_type_collection`` classifies as ``set``.
    """
    if isinstance(obj, _set_binop_bases + (self.__class__,)):
        return True
    return util.duck_type_collection(obj) == set
1355
1356
def _set_decorators():
    """Tailored instrumentation wrappers for any set-like class.

    Returns a dict mapping set method names to decorator factories.  Each
    factory receives the class's original method ``fn`` and returns a
    replacement that fires the ORM append/remove events around it (bulk
    operations are routed through the instrumented ``add``, ``remove``
    and ``discard``).  NOTE: the result is built from ``locals()``, so no
    other local names may be defined at this level.
    """

    def _tidy(fn):
        fn._sa_instrumented = True
        fn.__doc__ = getattr(set, fn.__name__).__doc__

    # unused by the set wrappers, but popped from the result below
    Unspecified = util.symbol("Unspecified")

    def add(fn):
        def add(self, value, _sa_initiator=None):
            if value not in self:
                value = __set(self, value, _sa_initiator)
            # testlib.pragma exempt:__hash__
            fn(self, value)

        _tidy(add)
        return add

    def discard(fn):
        def discard(self, value, _sa_initiator=None):
            # testlib.pragma exempt:__hash__
            if value in self:
                __del(self, value, _sa_initiator)
            # testlib.pragma exempt:__hash__
            fn(self, value)

        _tidy(discard)
        return discard

    def remove(fn):
        def remove(self, value, _sa_initiator=None):
            # testlib.pragma exempt:__hash__
            if value in self:
                __del(self, value, _sa_initiator)
            # testlib.pragma exempt:__hash__
            fn(self, value)

        _tidy(remove)
        return remove

    def pop(fn):
        def pop(self):
            __before_delete(self)
            item = fn(self)
            __del(self, item)
            return item

        _tidy(pop)
        return pop

    def clear(fn):
        def clear(self):
            for item in list(self):
                self.remove(item)

        _tidy(clear)
        return clear

    def update(fn):
        # like set.update(), accept any number of iterables
        def update(self, value, *others):
            for iterable in (value,) + others:
                for item in iterable:
                    self.add(item)

        _tidy(update)
        return update

    def __ior__(fn):
        def __ior__(self, value):
            if not _set_binops_check_strict(self, value):
                return NotImplemented
            for item in value:
                self.add(item)
            return self

        _tidy(__ior__)
        return __ior__

    def difference_update(fn):
        # like set.difference_update(), accept any number of iterables
        def difference_update(self, value, *others):
            for iterable in (value,) + others:
                # discarding from self while iterating self would raise
                # "Set changed size during iteration"
                if iterable is self:
                    iterable = list(iterable)
                for item in iterable:
                    self.discard(item)

        _tidy(difference_update)
        return difference_update

    def __isub__(fn):
        def __isub__(self, value):
            if not _set_binops_check_strict(self, value):
                return NotImplemented
            # ``s -= s``: snapshot so the discards don't mutate the set
            # being iterated
            if value is self:
                value = list(value)
            for item in value:
                self.discard(item)
            return self

        _tidy(__isub__)
        return __isub__

    def intersection_update(fn):
        # like set.intersection_update(), accept any number of iterables
        def intersection_update(self, other, *others):
            want, have = self.intersection(other, *others), set(self)
            remove, add = have - want, want - have

            for item in remove:
                self.remove(item)
            for item in add:
                self.add(item)

        _tidy(intersection_update)
        return intersection_update

    def __iand__(fn):
        def __iand__(self, other):
            if not _set_binops_check_strict(self, other):
                return NotImplemented
            want, have = self.intersection(other), set(self)
            remove, add = have - want, want - have

            for item in remove:
                self.remove(item)
            for item in add:
                self.add(item)
            return self

        _tidy(__iand__)
        return __iand__

    def symmetric_difference_update(fn):
        def symmetric_difference_update(self, other):
            want, have = self.symmetric_difference(other), set(self)
            remove, add = have - want, want - have

            for item in remove:
                self.remove(item)
            for item in add:
                self.add(item)

        _tidy(symmetric_difference_update)
        return symmetric_difference_update

    def __ixor__(fn):
        def __ixor__(self, other):
            if not _set_binops_check_strict(self, other):
                return NotImplemented
            want, have = self.symmetric_difference(other), set(self)
            remove, add = have - want, want - have

            for item in remove:
                self.remove(item)
            for item in add:
                self.add(item)
            return self

        _tidy(__ixor__)
        return __ixor__

    l = locals().copy()
    l.pop("_tidy")
    l.pop("Unspecified")
    return l
1516
1517
class InstrumentedList(list):
    """A ``list`` subclass the ORM uses in place of plain ``list``.

    The body is intentionally empty: its mutator methods are replaced in
    place by event-firing wrappers when the class is instrumented.
    """
1520
1521
class InstrumentedSet(set):
    """A ``set`` subclass the ORM uses in place of plain ``set``.

    The body is intentionally empty: its mutator methods are replaced in
    place by event-firing wrappers when the class is instrumented.
    """
1524
1525
class InstrumentedDict(dict):
    """A ``dict`` subclass the ORM uses in place of plain ``dict``.

    The body is intentionally empty; any instrumentation is applied to
    the class from outside.
    """
1528
1529
# Maps each plain built-in collection type to the trivial subclass used in
# its place; __converting_factory looks up this table so that a factory
# producing one of these types returns the Instrumented* subclass instead.
__canned_instrumentation = {
    list: InstrumentedList,
    set: InstrumentedSet,
    dict: InstrumentedDict,
}
1535
# For each duck-typed collection type: (canned role -> method name map,
# method name -> decorator factory map).  Consumed by _setup_canned_roles.
__interfaces = {
    list: (
        {"appender": "append", "remover": "remove", "iterator": "__iter__"},
        _list_decorators(),
    ),
    set: (
        {"appender": "add", "remover": "remove", "iterator": "__iter__"},
        _set_decorators(),
    ),
    # decorators are required for dicts and object collections.
    # Only the iterator is canned for dicts; its method name depends on
    # the Python major version.
    dict: (
        {"iterator": "values" if util.py3k else "itervalues"},
        _dict_decorators(),
    ),
}
1550
1551
class MappedCollection(dict):
    """A basic dictionary-based collection class.

    A ``dict`` extended with the minimal bag semantics that collection
    classes require: members are added and removed by value, with the
    dictionary key computed by a keying function (any callable that
    takes an object and returns an object usable as a dictionary key).

    """

    def __init__(self, keyfunc):
        """Create a new collection keyed by ``keyfunc``.

        ``keyfunc`` may be any callable that takes an object and returns
        an object for use as a dictionary key.

        It is called whenever the ORM adds a member by value alone (such
        as when loading instances from the database) and whenever a
        member is removed.  The usual cautions about dictionary keying
        apply: ``keyfunc(object)`` should return the same output for the
        life of the collection.  Keying on mutable properties can leave
        instances unreachable, "lost" in the collection.

        """
        self.keyfunc = keyfunc

    @collection.appender
    @collection.internally_instrumented
    def set(self, value, _sa_initiator=None):
        """Add ``value``, storing it under ``keyfunc(value)``."""
        computed_key = self.keyfunc(value)
        self.__setitem__(computed_key, value, _sa_initiator)

    @collection.remover
    @collection.internally_instrumented
    def remove(self, value, _sa_initiator=None):
        """Remove ``value``, locating it by ``keyfunc(value)``.

        Raises KeyError when the computed key is absent, and
        InvalidRequestError when a different object is stored there.
        """
        computed_key = self.keyfunc(value)
        # Let self[key] raise if key is not in this collection
        # testlib.pragma exempt:__ne__
        holds_other = self[computed_key] != value
        if holds_other:
            raise sa_exc.InvalidRequestError(
                "Can not remove '%s': collection holds '%s' for key '%s'. "
                "Possible cause: is the MappedCollection key function "
                "based on mutable properties or properties that only obtain "
                "values after flush?"
                % (value, self[computed_key], computed_key)
            )
        self.__delitem__(computed_key, _sa_initiator)

    @collection.converter
    def _convert(self, dictlike):
        """Lazily validate a dict-like object, yielding its values.

        Used behind the scenes when a MappedCollection is replaced
        wholesale by another collection, as in::

          myobj.mappedcollection = {'a':obj1, 'b': obj2} # ...

        Raises TypeError as soon as a (key, value) pair is reached whose
        key differs from what this collection's keyfunc computes for the
        value.

        """
        for supplied_key, member in util.dictlike_iteritems(dictlike):
            required_key = self.keyfunc(member)
            if supplied_key != required_key:
                raise TypeError(
                    "Found incompatible key %r for value %r; this "
                    "collection's "
                    "keying function requires a key of %r for this value."
                    % (supplied_key, member, required_key)
                )
            yield member
1627
1628
# ensure instrumentation is associated with
# these library-provided classes; if a user-defined class
# subclasses these and uses @internally_instrumented,
# the superclass is otherwise not instrumented.
# see [ticket:2406].
# InstrumentedDict is deliberately absent: the canned dict roles in
# __interfaces supply only an "iterator", so _assert_required_roles would
# reject it for lacking an appender.
_instrument_class(MappedCollection)
_instrument_class(InstrumentedList)
_instrument_class(InstrumentedSet)
1637