1import contextlib
2from operator import and_
3
4from sqlalchemy import event
5from sqlalchemy import exc as sa_exc
6from sqlalchemy import ForeignKey
7from sqlalchemy import Integer
8from sqlalchemy import String
9from sqlalchemy import testing
10from sqlalchemy import text
11from sqlalchemy import util
12from sqlalchemy.orm import attributes
13from sqlalchemy.orm import declarative_base
14from sqlalchemy.orm import instrumentation
15from sqlalchemy.orm import relationship
16import sqlalchemy.orm.collections as collections
17from sqlalchemy.orm.collections import collection
18from sqlalchemy.testing import assert_raises
19from sqlalchemy.testing import assert_raises_message
20from sqlalchemy.testing import eq_
21from sqlalchemy.testing import fixtures
22from sqlalchemy.testing import is_false
23from sqlalchemy.testing import is_true
24from sqlalchemy.testing import ne_
25from sqlalchemy.testing.fixtures import fixture_session
26from sqlalchemy.testing.schema import Column
27from sqlalchemy.testing.schema import Table
28
29
class Canary(object):
    """Attribute-event listener that records collection membership traffic.

    ``data`` mirrors the members the listener believes are currently in the
    collection.  ``added`` / ``removed`` remember every value that has ever
    passed through the corresponding event so duplicate events trip an
    assertion (while ``dupe_check`` is on).  ``appended_wo_mutation``
    collects values reported by the "append_wo_mutation" event.
    """

    def __init__(self):
        self.data, self.added, self.removed = set(), set(), set()
        self.appended_wo_mutation = set()
        self.dupe_check = True

    @contextlib.contextmanager
    def defer_dupe_check(self):
        """Switch duplicate-event assertions off for the ``with`` body.

        ``dupe_check`` is set back to True on exit, even on error.
        """
        self.dupe_check = False
        try:
            yield
        finally:
            self.dupe_check = True

    def listen(self, attr):
        """Attach this canary's handlers to the instrumented *attr*."""
        handlers = (
            ("append", self.append),
            ("append_wo_mutation", self.append_wo_mutation),
            ("remove", self.remove),
            ("set", self.set),
        )
        for identifier, handler in handlers:
            event.listen(attr, identifier, handler)

    def append(self, obj, value, initiator):
        """Record an append of *value*; returns *value* unchanged."""
        if self.dupe_check:
            # each value may only be appended (with mutation) once
            assert value not in self.added
            self.added.add(value)
        self.data.add(value)
        return value

    def append_wo_mutation(self, obj, value, initiator):
        """Record a non-mutating append of an already-added *value*."""
        if not self.dupe_check:
            return
        assert value in self.added
        self.appended_wo_mutation.add(value)

    def remove(self, obj, value, initiator):
        """Record removal of *value*; it must currently be in ``data``."""
        if self.dupe_check:
            assert value not in self.removed
            self.removed.add(value)
        self.data.remove(value)

    def set(self, obj, value, oldvalue, initiator):
        """Treat a scalar set as remove(old) + append(new).

        A ``str`` *value* is swapped for a freshly made Entity; the value
        actually appended is returned.
        """
        replacement = (
            CollectionsTest.entity_maker()
            if isinstance(value, str)
            else value
        )
        if oldvalue is not None:
            self.remove(obj, oldvalue, None)
        self.append(obj, replacement, None)
        return replacement
78
79
class OrderedDictFixture(object):
    """Mixin supplying a fixture that builds an ordered MappedCollection."""

    @testing.fixture
    def ordered_dict_mro(self):
        """Return a new ``MappedCollection`` subclass named "ordered".

        When the ``python37`` requirement is not enabled,
        ``util.OrderedDict`` is mixed in ahead of ``MappedCollection``
        (presumably to guarantee insertion ordering on older interpreters).
        """
        if testing.requires.python37.enabled:
            bases = (collections.MappedCollection,)
        else:
            bases = (util.OrderedDict, collections.MappedCollection)
        return type("ordered", bases, {})
89
90
91class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
92    class Entity(object):
93        def __init__(self, a=None, b=None, c=None):
94            self.a = a
95            self.b = b
96            self.c = c
97
98        def __repr__(self):
99            return str((id(self), self.a, self.b, self.c))
100
101    @classmethod
102    def setup_test_class(cls):
103        instrumentation.register_class(cls.Entity)
104
105    @classmethod
106    def teardown_test_class(cls):
107        instrumentation.unregister_class(cls.Entity)
108
109    _entity_id = 1
110
111    @classmethod
112    def entity_maker(cls):
113        cls._entity_id += 1
114        return cls.Entity(cls._entity_id)
115
116    @classmethod
117    def dictable_entity(cls, a=None, b=None, c=None):
118        id_ = cls._entity_id = cls._entity_id + 1
119        return cls.Entity(a or str(id_), b or "value %s" % id, c)
120
121    def _test_adapter(self, typecallable, creator=None, to_set=None):
122        if creator is None:
123            creator = self.entity_maker
124
125        class Foo(object):
126            pass
127
128        canary = Canary()
129        instrumentation.register_class(Foo)
130        d = attributes.register_attribute(
131            Foo,
132            "attr",
133            uselist=True,
134            typecallable=typecallable,
135            useobject=True,
136        )
137        canary.listen(d)
138
139        obj = Foo()
140        adapter = collections.collection_adapter(obj.attr)
141        direct = obj.attr
142        if to_set is None:
143
144            def to_set(col):
145                return set(col)
146
147        def assert_eq():
148            self.assert_(to_set(direct) == canary.data)
149            self.assert_(set(adapter) == canary.data)
150
151        def assert_ne():
152            self.assert_(to_set(direct) != canary.data)
153
154        e1, e2 = creator(), creator()
155
156        adapter.append_with_event(e1)
157        assert_eq()
158
159        adapter.append_without_event(e2)
160        assert_ne()
161        canary.data.add(e2)
162        assert_eq()
163
164        adapter.remove_without_event(e2)
165        assert_ne()
166        canary.data.remove(e2)
167        assert_eq()
168
169        adapter.remove_with_event(e1)
170        assert_eq()
171
172        self._test_empty_init(typecallable, creator=creator)
173
174    def _test_empty_init(self, typecallable, creator=None):
175        if creator is None:
176            creator = self.entity_maker
177
178        class Foo(object):
179            pass
180
181        instrumentation.register_class(Foo)
182        attributes.register_attribute(
183            Foo,
184            "attr",
185            uselist=True,
186            typecallable=typecallable,
187            useobject=True,
188        )
189
190        obj = Foo()
191        e1 = creator()
192        e2 = creator()
193        implicit_collection = obj.attr
194        is_true("attr" not in obj.__dict__)
195        adapter = collections.collection_adapter(implicit_collection)
196        is_true(adapter.empty)
197        assert_raises_message(
198            sa_exc.InvalidRequestError,
199            "This is a special 'empty'",
200            adapter.append_without_event,
201            e1,
202        )
203
204        adapter.append_with_event(e1)
205        is_false(adapter.empty)
206        is_true("attr" in obj.__dict__)
207        adapter.append_without_event(e2)
208        eq_(set(adapter), {e1, e2})
209
210    def _test_list(self, typecallable, creator=None):
211        if creator is None:
212            creator = self.entity_maker
213
214        class Foo(object):
215            pass
216
217        canary = Canary()
218        instrumentation.register_class(Foo)
219        d = attributes.register_attribute(
220            Foo,
221            "attr",
222            uselist=True,
223            typecallable=typecallable,
224            useobject=True,
225        )
226        canary.listen(d)
227
228        obj = Foo()
229        adapter = collections.collection_adapter(obj.attr)
230        direct = obj.attr
231        control = list()
232
233        def assert_eq():
234            eq_(set(direct), canary.data)
235            eq_(set(adapter), canary.data)
236            eq_(direct, control)
237
238        # assume append() is available for list tests
239        e = creator()
240        direct.append(e)
241        control.append(e)
242        assert_eq()
243
244        if hasattr(direct, "pop"):
245            direct.pop()
246            control.pop()
247            assert_eq()
248
249        if hasattr(direct, "__setitem__"):
250            e = creator()
251            direct.append(e)
252            control.append(e)
253
254            e = creator()
255            direct[0] = e
256            control[0] = e
257            assert_eq()
258
259            if util.reduce(
260                and_,
261                [
262                    hasattr(direct, a)
263                    for a in ("__delitem__", "insert", "__len__")
264                ],
265                True,
266            ):
267                values = [creator(), creator(), creator(), creator()]
268                direct[slice(0, 1)] = values
269                control[slice(0, 1)] = values
270                assert_eq()
271
272                values = [creator(), creator()]
273                direct[slice(0, -1, 2)] = values
274                control[slice(0, -1, 2)] = values
275                assert_eq()
276
277                values = [creator()]
278                direct[slice(0, -1)] = values
279                control[slice(0, -1)] = values
280                assert_eq()
281
282                values = [creator(), creator(), creator()]
283                control[:] = values
284                direct[:] = values
285
286                def invalid():
287                    direct[slice(0, 6, 2)] = [creator()]
288
289                assert_raises(ValueError, invalid)
290
291        if hasattr(direct, "__delitem__"):
292            e = creator()
293            direct.append(e)
294            control.append(e)
295            del direct[-1]
296            del control[-1]
297            assert_eq()
298
299            if hasattr(direct, "__getslice__"):
300                for e in [creator(), creator(), creator(), creator()]:
301                    direct.append(e)
302                    control.append(e)
303
304                del direct[:-3]
305                del control[:-3]
306                assert_eq()
307
308                del direct[0:1]
309                del control[0:1]
310                assert_eq()
311
312                del direct[::2]
313                del control[::2]
314                assert_eq()
315
316        if hasattr(direct, "remove"):
317            e = creator()
318            direct.append(e)
319            control.append(e)
320
321            direct.remove(e)
322            control.remove(e)
323            assert_eq()
324
325        if hasattr(direct, "__setitem__") or hasattr(direct, "__setslice__"):
326
327            values = [creator(), creator()]
328            direct[:] = values
329            control[:] = values
330            assert_eq()
331
332            # test slice assignment where we slice assign to self,
333            # currently a no-op, issue #4990
334            # note that in py2k, the bug does not exist but it recreates
335            # the collection which breaks our fixtures here
336            with canary.defer_dupe_check():
337                direct[:] = direct
338                control[:] = control
339            assert_eq()
340
341            # we dont handle assignment of self to slices, as this
342            # implies duplicate entries.  behavior here is not well defined
343            # and perhaps should emit a warning
344            # direct[0:1] = list(direct)
345            # control[0:1] = list(control)
346            # assert_eq()
347
348            # test slice assignment where
349            # slice size goes over the number of items
350            values = [creator(), creator()]
351            direct[1:3] = values
352            control[1:3] = values
353            assert_eq()
354
355            values = [creator(), creator()]
356            direct[0:1] = values
357            control[0:1] = values
358            assert_eq()
359
360            values = [creator()]
361            direct[0:] = values
362            control[0:] = values
363            assert_eq()
364
365            values = [creator()]
366            direct[:1] = values
367            control[:1] = values
368            assert_eq()
369
370            values = [creator()]
371            direct[-1::2] = values
372            control[-1::2] = values
373            assert_eq()
374
375            values = [creator()] * len(direct[1::2])
376            direct[1::2] = values
377            control[1::2] = values
378            assert_eq()
379
380            values = [creator(), creator()]
381            direct[-1:-3] = values
382            control[-1:-3] = values
383            assert_eq()
384
385            values = [creator(), creator()]
386            direct[-2:-1] = values
387            control[-2:-1] = values
388            assert_eq()
389
390            values = [creator()]
391            direct[0:0] = values
392            control[0:0] = values
393            assert_eq()
394
395        if hasattr(direct, "__delitem__") or hasattr(direct, "__delslice__"):
396            for i in range(1, 4):
397                e = creator()
398                direct.append(e)
399                control.append(e)
400
401            del direct[-1:]
402            del control[-1:]
403            assert_eq()
404
405            del direct[1:2]
406            del control[1:2]
407            assert_eq()
408
409            del direct[:]
410            del control[:]
411            assert_eq()
412
413        if hasattr(direct, "clear"):
414            for i in range(1, 4):
415                e = creator()
416                direct.append(e)
417                control.append(e)
418
419            direct.clear()
420            control.clear()
421            assert_eq()
422
423        if hasattr(direct, "extend"):
424            values = [creator(), creator(), creator()]
425
426            direct.extend(values)
427            control.extend(values)
428            assert_eq()
429
430        if hasattr(direct, "__iadd__"):
431            values = [creator(), creator(), creator()]
432
433            direct += values
434            control += values
435            assert_eq()
436
437            direct += []
438            control += []
439            assert_eq()
440
441            values = [creator(), creator()]
442            obj.attr += values
443            control += values
444            assert_eq()
445
446        if hasattr(direct, "__imul__"):
447            direct *= 2
448            control *= 2
449            assert_eq()
450
451            obj.attr *= 2
452            control *= 2
453            assert_eq()
454
455    def _test_list_bulk(self, typecallable, creator=None):
456        if creator is None:
457            creator = self.entity_maker
458
459        class Foo(object):
460            pass
461
462        canary = Canary()
463        instrumentation.register_class(Foo)
464        d = attributes.register_attribute(
465            Foo,
466            "attr",
467            uselist=True,
468            typecallable=typecallable,
469            useobject=True,
470        )
471        canary.listen(d)
472
473        obj = Foo()
474        direct = obj.attr
475
476        e1 = creator()
477        obj.attr.append(e1)
478
479        like_me = typecallable()
480        e2 = creator()
481        like_me.append(e2)
482
483        self.assert_(obj.attr is direct)
484        obj.attr = like_me
485        self.assert_(obj.attr is not direct)
486        self.assert_(obj.attr is not like_me)
487        self.assert_(set(obj.attr) == set([e2]))
488        self.assert_(e1 in canary.removed)
489        self.assert_(e2 in canary.added)
490
491        e3 = creator()
492        real_list = [e3]
493        obj.attr = real_list
494        self.assert_(obj.attr is not real_list)
495        self.assert_(set(obj.attr) == set([e3]))
496        self.assert_(e2 in canary.removed)
497        self.assert_(e3 in canary.added)
498
499        e4 = creator()
500        try:
501            obj.attr = set([e4])
502            self.assert_(False)
503        except TypeError:
504            self.assert_(e4 not in canary.data)
505            self.assert_(e3 in canary.data)
506
507        e5 = creator()
508        e6 = creator()
509        e7 = creator()
510        obj.attr = [e5, e6, e7]
511        self.assert_(e5 in canary.added)
512        self.assert_(e6 in canary.added)
513        self.assert_(e7 in canary.added)
514
515        obj.attr = [e6, e7]
516        self.assert_(e5 in canary.removed)
517        self.assert_(e6 in canary.added)
518        self.assert_(e7 in canary.added)
519        self.assert_(e6 not in canary.removed)
520        self.assert_(e7 not in canary.removed)
521
522    def test_list(self):
523        self._test_adapter(list)
524        self._test_list(list)
525        self._test_list_bulk(list)
526
527    def test_list_setitem_with_slices(self):
528
529        # this is a "list" that has no __setslice__
530        # or __delslice__ methods.  The __setitem__
531        # and __delitem__ must therefore accept
532        # slice objects (i.e. as in py3k)
533        class ListLike(object):
534            def __init__(self):
535                self.data = list()
536
537            def append(self, item):
538                self.data.append(item)
539
540            def remove(self, item):
541                self.data.remove(item)
542
543            def insert(self, index, item):
544                self.data.insert(index, item)
545
546            def pop(self, index=-1):
547                return self.data.pop(index)
548
549            def extend(self):
550                assert False
551
552            def __len__(self):
553                return len(self.data)
554
555            def __setitem__(self, key, value):
556                self.data[key] = value
557
558            def __getitem__(self, key):
559                return self.data[key]
560
561            def __delitem__(self, key):
562                del self.data[key]
563
564            def __iter__(self):
565                return iter(self.data)
566
567            __hash__ = object.__hash__
568
569            def __eq__(self, other):
570                return self.data == other
571
572            def __repr__(self):
573                return "ListLike(%s)" % repr(self.data)
574
575        self._test_adapter(ListLike)
576        self._test_list(ListLike)
577        self._test_list_bulk(ListLike)
578
579    def test_list_subclass(self):
580        class MyList(list):
581            pass
582
583        self._test_adapter(MyList)
584        self._test_list(MyList)
585        self._test_list_bulk(MyList)
586        self.assert_(getattr(MyList, "_sa_instrumented") == id(MyList))
587
588    def test_list_duck(self):
589        class ListLike(object):
590            def __init__(self):
591                self.data = list()
592
593            def append(self, item):
594                self.data.append(item)
595
596            def remove(self, item):
597                self.data.remove(item)
598
599            def insert(self, index, item):
600                self.data.insert(index, item)
601
602            def pop(self, index=-1):
603                return self.data.pop(index)
604
605            def extend(self):
606                assert False
607
608            def __iter__(self):
609                return iter(self.data)
610
611            __hash__ = object.__hash__
612
613            def __eq__(self, other):
614                return self.data == other
615
616            def __repr__(self):
617                return "ListLike(%s)" % repr(self.data)
618
619        self._test_adapter(ListLike)
620        self._test_list(ListLike)
621        self._test_list_bulk(ListLike)
622        self.assert_(getattr(ListLike, "_sa_instrumented") == id(ListLike))
623
624    def test_list_emulates(self):
625        class ListIsh(object):
626            __emulates__ = list
627
628            def __init__(self):
629                self.data = list()
630
631            def append(self, item):
632                self.data.append(item)
633
634            def remove(self, item):
635                self.data.remove(item)
636
637            def insert(self, index, item):
638                self.data.insert(index, item)
639
640            def pop(self, index=-1):
641                return self.data.pop(index)
642
643            def extend(self):
644                assert False
645
646            def __iter__(self):
647                return iter(self.data)
648
649            __hash__ = object.__hash__
650
651            def __eq__(self, other):
652                return self.data == other
653
654            def __repr__(self):
655                return "ListIsh(%s)" % repr(self.data)
656
657        self._test_adapter(ListIsh)
658        self._test_list(ListIsh)
659        self._test_list_bulk(ListIsh)
660        self.assert_(getattr(ListIsh, "_sa_instrumented") == id(ListIsh))
661
662    def _test_set_wo_mutation(self, typecallable, creator=None):
663        if creator is None:
664            creator = self.entity_maker
665
666        class Foo(object):
667            pass
668
669        canary = Canary()
670        instrumentation.register_class(Foo)
671        d = attributes.register_attribute(
672            Foo,
673            "attr",
674            uselist=True,
675            typecallable=typecallable,
676            useobject=True,
677        )
678        canary.listen(d)
679
680        obj = Foo()
681
682        e = creator()
683
684        obj.attr.add(e)
685
686        assert e in canary.added
687        assert e not in canary.appended_wo_mutation
688
689        obj.attr.add(e)
690        assert e in canary.added
691        assert e in canary.appended_wo_mutation
692
693        e = creator()
694
695        obj.attr.update({e})
696
697        assert e in canary.added
698        assert e not in canary.appended_wo_mutation
699
700        obj.attr.update({e})
701        assert e in canary.added
702        assert e in canary.appended_wo_mutation
703
704    def _test_set(self, typecallable, creator=None):
705        if creator is None:
706            creator = self.entity_maker
707
708        class Foo(object):
709            pass
710
711        canary = Canary()
712        instrumentation.register_class(Foo)
713        d = attributes.register_attribute(
714            Foo,
715            "attr",
716            uselist=True,
717            typecallable=typecallable,
718            useobject=True,
719        )
720        canary.listen(d)
721
722        obj = Foo()
723        adapter = collections.collection_adapter(obj.attr)
724        direct = obj.attr
725        control = set()
726
727        def assert_eq():
728            eq_(set(direct), canary.data)
729            eq_(set(adapter), canary.data)
730            eq_(direct, control)
731
732        def addall(*values):
733            for item in values:
734                direct.add(item)
735                control.add(item)
736            assert_eq()
737
738        def zap():
739            for item in list(direct):
740                direct.remove(item)
741            control.clear()
742
743        addall(creator())
744
745        e = creator()
746        addall(e)
747        addall(e)
748
749        if hasattr(direct, "remove"):
750            e = creator()
751            addall(e)
752
753            direct.remove(e)
754            control.remove(e)
755            assert_eq()
756
757            e = creator()
758            try:
759                direct.remove(e)
760            except KeyError:
761                assert_eq()
762                self.assert_(e not in canary.removed)
763            else:
764                self.assert_(False)
765
766        if hasattr(direct, "discard"):
767            e = creator()
768            addall(e)
769
770            direct.discard(e)
771            control.discard(e)
772            assert_eq()
773
774            e = creator()
775            direct.discard(e)
776            self.assert_(e not in canary.removed)
777            assert_eq()
778
779        if hasattr(direct, "update"):
780            zap()
781            e = creator()
782            addall(e)
783
784            values = set([e, creator(), creator()])
785
786            direct.update(values)
787            control.update(values)
788            assert_eq()
789
790        if hasattr(direct, "__ior__"):
791            zap()
792            e = creator()
793            addall(e)
794
795            values = set([e, creator(), creator()])
796
797            direct |= values
798            control |= values
799            assert_eq()
800
801            # cover self-assignment short-circuit
802            values = set([e, creator(), creator()])
803            obj.attr |= values
804            control |= values
805            assert_eq()
806
807            values = frozenset([e, creator()])
808            obj.attr |= values
809            control |= values
810            assert_eq()
811
812            try:
813                direct |= [e, creator()]
814                assert False
815            except TypeError:
816                assert True
817
818        addall(creator(), creator())
819        direct.clear()
820        control.clear()
821        assert_eq()
822
823        # note: the clear test previously needs
824        # to have executed in order for this to
825        # pass in all cases; else there's the possibility
826        # of non-deterministic behavior.
827        addall(creator())
828        direct.pop()
829        control.pop()
830        assert_eq()
831
832        if hasattr(direct, "difference_update"):
833            zap()
834            e = creator()
835            addall(creator(), creator())
836            values = set([creator()])
837
838            direct.difference_update(values)
839            control.difference_update(values)
840            assert_eq()
841            values.update(set([e, creator()]))
842            direct.difference_update(values)
843            control.difference_update(values)
844            assert_eq()
845
846        if hasattr(direct, "__isub__"):
847            zap()
848            e = creator()
849            addall(creator(), creator())
850            values = set([creator()])
851
852            direct -= values
853            control -= values
854            assert_eq()
855            values.update(set([e, creator()]))
856            direct -= values
857            control -= values
858            assert_eq()
859
860            values = set([creator()])
861            obj.attr -= values
862            control -= values
863            assert_eq()
864
865            values = frozenset([creator()])
866            obj.attr -= values
867            control -= values
868            assert_eq()
869
870            try:
871                direct -= [e, creator()]
872                assert False
873            except TypeError:
874                assert True
875
876        if hasattr(direct, "intersection_update"):
877            zap()
878            e = creator()
879            addall(e, creator(), creator())
880            values = set(control)
881
882            direct.intersection_update(values)
883            control.intersection_update(values)
884            assert_eq()
885
886            values.update(set([e, creator()]))
887            direct.intersection_update(values)
888            control.intersection_update(values)
889            assert_eq()
890
891        if hasattr(direct, "__iand__"):
892            zap()
893            e = creator()
894            addall(e, creator(), creator())
895            values = set(control)
896
897            direct &= values
898            control &= values
899            assert_eq()
900
901            values.update(set([e, creator()]))
902            direct &= values
903            control &= values
904            assert_eq()
905
906            values.update(set([creator()]))
907            obj.attr &= values
908            control &= values
909            assert_eq()
910
911            try:
912                direct &= [e, creator()]
913                assert False
914            except TypeError:
915                assert True
916
917        if hasattr(direct, "symmetric_difference_update"):
918            zap()
919            e = creator()
920            addall(e, creator(), creator())
921
922            values = set([e, creator()])
923            direct.symmetric_difference_update(values)
924            control.symmetric_difference_update(values)
925            assert_eq()
926
927            e = creator()
928            addall(e)
929            values = set([e])
930            direct.symmetric_difference_update(values)
931            control.symmetric_difference_update(values)
932            assert_eq()
933
934            values = set()
935            direct.symmetric_difference_update(values)
936            control.symmetric_difference_update(values)
937            assert_eq()
938
939        if hasattr(direct, "__ixor__"):
940            zap()
941            e = creator()
942            addall(e, creator(), creator())
943
944            values = set([e, creator()])
945            direct ^= values
946            control ^= values
947            assert_eq()
948
949            e = creator()
950            addall(e)
951            values = set([e])
952            direct ^= values
953            control ^= values
954            assert_eq()
955
956            values = set()
957            direct ^= values
958            control ^= values
959            assert_eq()
960
961            values = set([creator()])
962            obj.attr ^= values
963            control ^= values
964            assert_eq()
965
966            try:
967                direct ^= [e, creator()]
968                assert False
969            except TypeError:
970                assert True
971
972    def _test_set_bulk(self, typecallable, creator=None):
973        if creator is None:
974            creator = self.entity_maker
975
976        class Foo(object):
977            pass
978
979        canary = Canary()
980        instrumentation.register_class(Foo)
981        d = attributes.register_attribute(
982            Foo,
983            "attr",
984            uselist=True,
985            typecallable=typecallable,
986            useobject=True,
987        )
988        canary.listen(d)
989
990        obj = Foo()
991        direct = obj.attr
992
993        e1 = creator()
994        obj.attr.add(e1)
995
996        like_me = typecallable()
997        e2 = creator()
998        like_me.add(e2)
999
1000        self.assert_(obj.attr is direct)
1001        obj.attr = like_me
1002        self.assert_(obj.attr is not direct)
1003        self.assert_(obj.attr is not like_me)
1004        self.assert_(obj.attr == set([e2]))
1005        self.assert_(e1 in canary.removed)
1006        self.assert_(e2 in canary.added)
1007
1008        e3 = creator()
1009        real_set = set([e3])
1010        obj.attr = real_set
1011        self.assert_(obj.attr is not real_set)
1012        self.assert_(obj.attr == set([e3]))
1013        self.assert_(e2 in canary.removed)
1014        self.assert_(e3 in canary.added)
1015
1016        e4 = creator()
1017        try:
1018            obj.attr = [e4]
1019            self.assert_(False)
1020        except TypeError:
1021            self.assert_(e4 not in canary.data)
1022            self.assert_(e3 in canary.data)
1023
1024    def test_set(self):
1025        self._test_adapter(set)
1026        self._test_set(set)
1027        self._test_set_bulk(set)
1028        self._test_set_wo_mutation(set)
1029
1030    def test_set_subclass(self):
1031        class MySet(set):
1032            pass
1033
1034        self._test_adapter(MySet)
1035        self._test_set(MySet)
1036        self._test_set_bulk(MySet)
1037        self.assert_(getattr(MySet, "_sa_instrumented") == id(MySet))
1038
    def test_set_duck(self):
        """A set-like class with no set base and no ``__emulates__``.

        ``SetLike`` delegates ``add``/``remove``/``discard``/``clear``/
        ``pop``/``update``/``__iter__`` to an internal ``set`` and declares
        no collection roles. The adapter, mutation and bulk-assignment
        suites are run against it (duck typing; contrast with
        ``test_set_emulates``), and the class must end up tagged as
        instrumented.
        """
        class SetLike(object):
            def __init__(self):
                self.data = set()

            def add(self, item):
                self.data.add(item)

            def remove(self, item):
                self.data.remove(item)

            def discard(self, item):
                self.data.discard(item)

            def clear(self):
                self.data.clear()

            def pop(self):
                return self.data.pop()

            def update(self, other):
                self.data.update(other)

            def __iter__(self):
                return iter(self.data)

            # defining __eq__ disables inherited hashing on Python 3;
            # restore identity hashing explicitly
            __hash__ = object.__hash__

            def __eq__(self, other):
                # compare by contents so the suites can test equality
                # against a plain set
                return self.data == other

        self._test_adapter(SetLike)
        self._test_set(SetLike)
        self._test_set_bulk(SetLike)
        # instrumentation tags the class with its own id
        self.assert_(getattr(SetLike, "_sa_instrumented") == id(SetLike))
1074
    def test_set_emulates(self):
        """A plain class that declares ``__emulates__ = set``.

        ``SetIsh`` has no set base class; its set-style methods delegate to
        an internal ``set``. The adapter, mutation and bulk-assignment
        suites are run against it, and the class must end up tagged as
        instrumented.
        """
        class SetIsh(object):
            # declare the collection interface explicitly instead of
            # relying on the method names alone
            __emulates__ = set

            def __init__(self):
                self.data = set()

            def add(self, item):
                self.data.add(item)

            def remove(self, item):
                self.data.remove(item)

            def discard(self, item):
                self.data.discard(item)

            def pop(self):
                return self.data.pop()

            def update(self, other):
                self.data.update(other)

            def __iter__(self):
                return iter(self.data)

            def clear(self):
                self.data.clear()

            # defining __eq__ disables inherited hashing on Python 3;
            # restore identity hashing explicitly
            __hash__ = object.__hash__

            def __eq__(self, other):
                # compare by contents so the suites can test equality
                # against a plain set
                return self.data == other

        self._test_adapter(SetIsh)
        self._test_set(SetIsh)
        self._test_set_bulk(SetIsh)
        # instrumentation tags the class with its own id
        self.assert_(getattr(SetIsh, "_sa_instrumented") == id(SetIsh))
1112
1113    def _test_dict_wo_mutation(self, typecallable, creator=None):
1114        if creator is None:
1115            creator = self.dictable_entity
1116
1117        class Foo(object):
1118            pass
1119
1120        canary = Canary()
1121        instrumentation.register_class(Foo)
1122        d = attributes.register_attribute(
1123            Foo,
1124            "attr",
1125            uselist=True,
1126            typecallable=typecallable,
1127            useobject=True,
1128        )
1129        canary.listen(d)
1130
1131        obj = Foo()
1132
1133        e = creator()
1134
1135        obj.attr[e.a] = e
1136        assert e in canary.added
1137        assert e not in canary.appended_wo_mutation
1138
1139        with canary.defer_dupe_check():
1140            # __setitem__ sets every time
1141            obj.attr[e.a] = e
1142            assert e in canary.added
1143            assert e not in canary.appended_wo_mutation
1144
1145        if hasattr(obj.attr, "update"):
1146            e = creator()
1147            obj.attr.update({e.a: e})
1148            assert e in canary.added
1149            assert e not in canary.appended_wo_mutation
1150
1151            obj.attr.update({e.a: e})
1152            assert e in canary.added
1153            assert e in canary.appended_wo_mutation
1154
1155            e = creator()
1156            obj.attr.update(**{e.a: e})
1157            assert e in canary.added
1158            assert e not in canary.appended_wo_mutation
1159
1160            obj.attr.update(**{e.a: e})
1161            assert e in canary.added
1162            assert e in canary.appended_wo_mutation
1163
1164        if hasattr(obj.attr, "setdefault"):
1165            e = creator()
1166            obj.attr.setdefault(e.a, e)
1167            assert e in canary.added
1168            assert e not in canary.appended_wo_mutation
1169
1170            obj.attr.setdefault(e.a, e)
1171            assert e in canary.added
1172            assert e in canary.appended_wo_mutation
1173
1174    def _test_dict(self, typecallable, creator=None):
1175        if creator is None:
1176            creator = self.dictable_entity
1177
1178        class Foo(object):
1179            pass
1180
1181        canary = Canary()
1182        instrumentation.register_class(Foo)
1183        d = attributes.register_attribute(
1184            Foo,
1185            "attr",
1186            uselist=True,
1187            typecallable=typecallable,
1188            useobject=True,
1189        )
1190        canary.listen(d)
1191
1192        obj = Foo()
1193        adapter = collections.collection_adapter(obj.attr)
1194        direct = obj.attr
1195        control = dict()
1196
1197        def assert_eq():
1198            self.assert_(set(direct.values()) == canary.data)
1199            self.assert_(set(adapter) == canary.data)
1200            self.assert_(direct == control)
1201
1202        def addall(*values):
1203            for item in values:
1204                direct.set(item)
1205                control[item.a] = item
1206            assert_eq()
1207
1208        def zap():
1209            for item in list(adapter):
1210                direct.remove(item)
1211            control.clear()
1212
1213        # assume an 'set' method is available for tests
1214        addall(creator())
1215
1216        if hasattr(direct, "__setitem__"):
1217            e = creator()
1218            direct[e.a] = e
1219            control[e.a] = e
1220            assert_eq()
1221
1222            e = creator(e.a, e.b)
1223            direct[e.a] = e
1224            control[e.a] = e
1225            assert_eq()
1226
1227        if hasattr(direct, "__delitem__"):
1228            e = creator()
1229            addall(e)
1230
1231            del direct[e.a]
1232            del control[e.a]
1233            assert_eq()
1234
1235            e = creator()
1236            try:
1237                del direct[e.a]
1238            except KeyError:
1239                self.assert_(e not in canary.removed)
1240
1241        if hasattr(direct, "clear"):
1242            addall(creator(), creator(), creator())
1243
1244            direct.clear()
1245            control.clear()
1246            assert_eq()
1247
1248            direct.clear()
1249            control.clear()
1250            assert_eq()
1251
1252        if hasattr(direct, "pop"):
1253            e = creator()
1254            addall(e)
1255
1256            direct.pop(e.a)
1257            control.pop(e.a)
1258            assert_eq()
1259
1260            e = creator()
1261            try:
1262                direct.pop(e.a)
1263            except KeyError:
1264                self.assert_(e not in canary.removed)
1265
1266        if hasattr(direct, "popitem"):
1267            zap()
1268            e = creator()
1269            addall(e)
1270
1271            direct.popitem()
1272            control.popitem()
1273            assert_eq()
1274
1275        if hasattr(direct, "setdefault"):
1276            e = creator()
1277
1278            val_a = direct.setdefault(e.a, e)
1279            val_b = control.setdefault(e.a, e)
1280            assert_eq()
1281            self.assert_(val_a is val_b)
1282
1283            val_a = direct.setdefault(e.a, e)
1284            val_b = control.setdefault(e.a, e)
1285            assert_eq()
1286            self.assert_(val_a is val_b)
1287
1288        if hasattr(direct, "update"):
1289            e = creator()
1290            d = dict([(ee.a, ee) for ee in [e, creator(), creator()]])
1291            addall(e, creator())
1292
1293            direct.update(d)
1294            control.update(d)
1295            assert_eq()
1296
1297            kw = dict([(ee.a, ee) for ee in [e, creator()]])
1298            direct.update(**kw)
1299            control.update(**kw)
1300            assert_eq()
1301
1302    def _test_dict_bulk(self, typecallable, creator=None):
1303        if creator is None:
1304            creator = self.dictable_entity
1305
1306        class Foo(object):
1307            pass
1308
1309        canary = Canary()
1310        instrumentation.register_class(Foo)
1311        d = attributes.register_attribute(
1312            Foo,
1313            "attr",
1314            uselist=True,
1315            typecallable=typecallable,
1316            useobject=True,
1317        )
1318        canary.listen(d)
1319
1320        obj = Foo()
1321        direct = obj.attr
1322
1323        e1 = creator()
1324        collections.collection_adapter(direct).append_with_event(e1)
1325
1326        like_me = typecallable()
1327        e2 = creator()
1328        like_me.set(e2)
1329
1330        self.assert_(obj.attr is direct)
1331        obj.attr = like_me
1332        self.assert_(obj.attr is not direct)
1333        self.assert_(obj.attr is not like_me)
1334        self.assert_(
1335            set(collections.collection_adapter(obj.attr)) == set([e2])
1336        )
1337        self.assert_(e1 in canary.removed)
1338        self.assert_(e2 in canary.added)
1339
1340        # key validity on bulk assignment is a basic feature of
1341        # MappedCollection but is not present in basic, @converter-less
1342        # dict collections.
1343        e3 = creator()
1344        real_dict = dict(keyignored1=e3)
1345        obj.attr = real_dict
1346        self.assert_(obj.attr is not real_dict)
1347        self.assert_("keyignored1" not in obj.attr)
1348        eq_(set(collections.collection_adapter(obj.attr)), set([e3]))
1349        self.assert_(e2 in canary.removed)
1350        self.assert_(e3 in canary.added)
1351
1352        obj.attr = typecallable()
1353        eq_(list(collections.collection_adapter(obj.attr)), [])
1354
1355        e4 = creator()
1356        try:
1357            obj.attr = [e4]
1358            self.assert_(False)
1359        except TypeError:
1360            self.assert_(e4 not in canary.data)
1361
1362    def test_dict(self):
1363        assert_raises_message(
1364            sa_exc.ArgumentError,
1365            "Type InstrumentedDict must elect an appender "
1366            "method to be a collection class",
1367            self._test_adapter,
1368            dict,
1369            self.dictable_entity,
1370            to_set=lambda c: set(c.values()),
1371        )
1372
1373        assert_raises_message(
1374            sa_exc.ArgumentError,
1375            "Type InstrumentedDict must elect an appender method "
1376            "to be a collection class",
1377            self._test_dict,
1378            dict,
1379        )
1380
1381    def test_dict_subclass(self):
1382        class MyDict(dict):
1383            @collection.appender
1384            @collection.internally_instrumented
1385            def set(self, item, _sa_initiator=None):
1386                self.__setitem__(item.a, item, _sa_initiator=_sa_initiator)
1387
1388            @collection.remover
1389            @collection.internally_instrumented
1390            def _remove(self, item, _sa_initiator=None):
1391                self.__delitem__(item.a, _sa_initiator=_sa_initiator)
1392
1393        self._test_adapter(
1394            MyDict, self.dictable_entity, to_set=lambda c: set(c.values())
1395        )
1396        self._test_dict(MyDict)
1397        self._test_dict_bulk(MyDict)
1398        self._test_dict_wo_mutation(MyDict)
1399        self.assert_(getattr(MyDict, "_sa_instrumented") == id(MyDict))
1400
1401    def test_dict_subclass2(self):
1402        class MyEasyDict(collections.MappedCollection):
1403            def __init__(self):
1404                super(MyEasyDict, self).__init__(lambda e: e.a)
1405
1406        self._test_adapter(
1407            MyEasyDict, self.dictable_entity, to_set=lambda c: set(c.values())
1408        )
1409        self._test_dict(MyEasyDict)
1410        self._test_dict_bulk(MyEasyDict)
1411        self._test_dict_wo_mutation(MyEasyDict)
1412        self.assert_(getattr(MyEasyDict, "_sa_instrumented") == id(MyEasyDict))
1413
1414    def test_dict_subclass3(self, ordered_dict_mro):
1415        class MyOrdered(ordered_dict_mro):
1416            def __init__(self):
1417                collections.MappedCollection.__init__(self, lambda e: e.a)
1418                util.OrderedDict.__init__(self)
1419
1420        self._test_adapter(
1421            MyOrdered, self.dictable_entity, to_set=lambda c: set(c.values())
1422        )
1423        self._test_dict(MyOrdered)
1424        self._test_dict_bulk(MyOrdered)
1425        self._test_dict_wo_mutation(MyOrdered)
1426        self.assert_(getattr(MyOrdered, "_sa_instrumented") == id(MyOrdered))
1427
    def test_dict_duck(self):
        """A dict-like class with no dict base and no ``__emulates__``.

        ``DictLike`` stores entities in an internal ``dict`` keyed on
        ``item.a`` and declares its roles explicitly:

        - ``set`` is the appender;
        - ``_remove`` is the remover;
        - ``itervalues`` is the iterator.

        All dict suites are run against it.
        """
        class DictLike(object):
            def __init__(self):
                self.data = dict()

            # replaces(1): the argument is added, and the returned value
            # (the entity previously under the same key, or None) is what
            # it displaced
            @collection.appender
            @collection.replaces(1)
            def set(self, item):
                current = self.data.get(item.a, None)
                self.data[item.a] = item
                return current

            @collection.remover
            def _remove(self, item):
                del self.data[item.a]

            def __setitem__(self, key, value):
                self.data[key] = value

            def __getitem__(self, key):
                return self.data[key]

            def __delitem__(self, key):
                del self.data[key]

            def values(self):
                return list(self.data.values())

            def __contains__(self, key):
                return key in self.data

            # iterating the collection yields values, not keys
            @collection.iterator
            def itervalues(self):
                return iter(self.data.values())

            # defining __eq__ disables inherited hashing on Python 3;
            # restore identity hashing explicitly
            __hash__ = object.__hash__

            def __eq__(self, other):
                # compare by contents so the suites can test equality
                # against a plain dict
                return self.data == other

            def __repr__(self):
                return "DictLike(%s)" % repr(self.data)

        self._test_adapter(
            DictLike, self.dictable_entity, to_set=lambda c: set(c.values())
        )
        self._test_dict(DictLike)
        self._test_dict_bulk(DictLike)
        self._test_dict_wo_mutation(DictLike)
        # instrumentation tags the class with its own id
        self.assert_(getattr(DictLike, "_sa_instrumented") == id(DictLike))
1478
    def test_dict_emulates(self):
        """A plain class that declares ``__emulates__ = dict``.

        ``DictIsh`` has the same shape as ``DictLike`` in
        ``test_dict_duck``, but announces its interface through
        ``__emulates__``. Entities are keyed on ``item.a``; ``set`` is the
        appender, ``_remove`` the remover and ``itervalues`` the iterator.
        All dict suites are run against it.
        """
        class DictIsh(object):
            __emulates__ = dict

            def __init__(self):
                self.data = dict()

            # replaces(1): the argument is added, and the returned value
            # (the entity previously under the same key, or None) is what
            # it displaced
            @collection.appender
            @collection.replaces(1)
            def set(self, item):
                current = self.data.get(item.a, None)
                self.data[item.a] = item
                return current

            @collection.remover
            def _remove(self, item):
                del self.data[item.a]

            def __setitem__(self, key, value):
                self.data[key] = value

            def __getitem__(self, key):
                return self.data[key]

            def __delitem__(self, key):
                del self.data[key]

            def values(self):
                return list(self.data.values())

            def __contains__(self, key):
                return key in self.data

            # iterating the collection yields values, not keys
            @collection.iterator
            def itervalues(self):
                return iter(self.data.values())

            # defining __eq__ disables inherited hashing on Python 3;
            # restore identity hashing explicitly
            __hash__ = object.__hash__

            def __eq__(self, other):
                # compare by contents so the suites can test equality
                # against a plain dict
                return self.data == other

            def __repr__(self):
                return "DictIsh(%s)" % repr(self.data)

        self._test_adapter(
            DictIsh, self.dictable_entity, to_set=lambda c: set(c.values())
        )
        self._test_dict(DictIsh)
        self._test_dict_bulk(DictIsh)
        self._test_dict_wo_mutation(DictIsh)
        # instrumentation tags the class with its own id
        self.assert_(getattr(DictIsh, "_sa_instrumented") == id(DictIsh))
1531
1532    def _test_object(self, typecallable, creator=None):
1533        if creator is None:
1534            creator = self.entity_maker
1535
1536        class Foo(object):
1537            pass
1538
1539        canary = Canary()
1540        instrumentation.register_class(Foo)
1541        d = attributes.register_attribute(
1542            Foo,
1543            "attr",
1544            uselist=True,
1545            typecallable=typecallable,
1546            useobject=True,
1547        )
1548        canary.listen(d)
1549
1550        obj = Foo()
1551        adapter = collections.collection_adapter(obj.attr)
1552        direct = obj.attr
1553        control = set()
1554
1555        def assert_eq():
1556            self.assert_(set(direct) == canary.data)
1557            self.assert_(set(adapter) == canary.data)
1558            self.assert_(direct == control)
1559
1560        # There is no API for object collections.  We'll make one up
1561        # for the purposes of the test.
1562        e = creator()
1563        direct.push(e)
1564        control.add(e)
1565        assert_eq()
1566
1567        direct.zark(e)
1568        control.remove(e)
1569        assert_eq()
1570
1571        e = creator()
1572        direct.maybe_zark(e)
1573        control.discard(e)
1574        assert_eq()
1575
1576        e = creator()
1577        direct.push(e)
1578        control.add(e)
1579        assert_eq()
1580
1581        e = creator()
1582        direct.maybe_zark(e)
1583        control.discard(e)
1584        assert_eq()
1585
    def test_object_duck(self):
        """An object-style collection made usable purely by decorated roles.

        ``MyCollection`` wraps a ``set``:

        - ``push`` is the appender;
        - ``zark`` is the remover;
        - ``maybe_zark`` removes only when present and returns the removed
          item (``removes_return``);
        - ``__iter__`` is the iterator.

        These are the method names ``_test_object`` calls.
        """
        class MyCollection(object):
            def __init__(self):
                self.data = set()

            @collection.appender
            def push(self, item):
                self.data.add(item)

            @collection.remover
            def zark(self, item):
                self.data.remove(item)

            # the return value (the item, or None when absent) identifies
            # what was removed
            @collection.removes_return()
            def maybe_zark(self, item):
                if item in self.data:
                    self.data.remove(item)
                    return item

            @collection.iterator
            def __iter__(self):
                return iter(self.data)

            # defining __eq__ disables inherited hashing on Python 3;
            # restore identity hashing explicitly
            __hash__ = object.__hash__

            def __eq__(self, other):
                # compare by contents so _test_object can compare against
                # a plain set control
                return self.data == other

        self._test_adapter(MyCollection)
        self._test_object(MyCollection)
        # instrumentation tags the class with its own id
        self.assert_(
            getattr(MyCollection, "_sa_instrumented") == id(MyCollection)
        )
1619
    def test_object_emulates(self):
        """``__emulates__ = None`` with a list-looking ``append`` method.

        ``MyCollection2`` defines an ``append`` that fails if called. The
        test therefore relies on the explicit ``__emulates__ = None`` and
        the decorated roles (``push``/``zark``/``maybe_zark``/``__iter__``)
        being used instead of that method.
        """
        class MyCollection2(object):
            __emulates__ = None

            def __init__(self):
                self.data = set()

            # looks like a list

            # trap: must never be used as the appender
            def append(self, item):
                assert False

            @collection.appender
            def push(self, item):
                self.data.add(item)

            @collection.remover
            def zark(self, item):
                self.data.remove(item)

            # the return value (the item, or None when absent) identifies
            # what was removed
            @collection.removes_return()
            def maybe_zark(self, item):
                if item in self.data:
                    self.data.remove(item)
                    return item

            @collection.iterator
            def __iter__(self):
                return iter(self.data)

            # defining __eq__ disables inherited hashing on Python 3;
            # restore identity hashing explicitly
            __hash__ = object.__hash__

            def __eq__(self, other):
                # compare by contents so _test_object can compare against
                # a plain set control
                return self.data == other

        self._test_adapter(MyCollection2)
        self._test_object(MyCollection2)
        # instrumentation tags the class with its own id
        self.assert_(
            getattr(MyCollection2, "_sa_instrumented") == id(MyCollection2)
        )
1660
    def test_recipes(self):
        """Role decorators can locate their argument by position or by name.

        ``Custom`` is a list-backed object collection whose methods use:

        - ``collection.adds`` / ``collection.removes`` with either a
          1-based positional index (``self`` not counted) or an argument
          name;
        - ``collection.replaces`` with a positional index;
        - ``collection.removes_return``.

        Each method is called with positional and/or keyword arguments and
        mirrored on a plain ``list`` control. Afterwards the collection,
        its adapter and the event canary must agree, and the objects
        returned by ``replace``/``pop`` must be the ones the control
        returns.
        """
        class Custom(object):
            def __init__(self):
                self.data = []

            # appender; the added entity is the argument named "entity"
            @collection.appender
            @collection.adds("entity")
            def put(self, entity):
                self.data.append(entity)

            # remover; the removed entity is positional argument 1
            @collection.remover
            @collection.removes(1)
            def remove(self, entity):
                self.data.remove(entity)

            # positional argument 1 is the first element of *args
            @collection.adds(1)
            def push(self, *args):
                self.data.append(args[0])

            # the removed entity is the argument named "entity"
            @collection.removes("entity")
            def yank(self, entity, arg):
                self.data.remove(entity)

            # positional argument 2 ("entity") is added; the returned
            # value is the entity it displaced
            @collection.replaces(2)
            def replace(self, arg, entity, **kw):
                self.data.insert(0, entity)
                return self.data.pop()

            # the returned value is the removed entity; ``key`` is ignored
            # and the last element is always popped
            @collection.removes_return()
            def pop(self, key):
                return self.data.pop()

            @collection.iterator
            def __iter__(self):
                return iter(self.data)

        class Foo(object):
            pass

        canary = Canary()
        instrumentation.register_class(Foo)
        d = attributes.register_attribute(
            Foo, "attr", uselist=True, typecallable=Custom, useobject=True
        )
        canary.listen(d)

        obj = Foo()
        adapter = collections.collection_adapter(obj.attr)
        direct = obj.attr
        control = list()

        def assert_eq():
            # ordering matters here, so compare against the list control
            self.assert_(set(direct) == canary.data)
            self.assert_(set(adapter) == canary.data)
            self.assert_(list(direct) == control)

        creator = self.entity_maker

        # adds("entity") matched positionally...
        e1 = creator()
        direct.put(e1)
        control.append(e1)
        assert_eq()

        # ...and by keyword
        e2 = creator()
        direct.put(entity=e2)
        control.append(e2)
        assert_eq()

        # removes(1) matched positionally...
        direct.remove(e2)
        control.remove(e2)
        assert_eq()

        # ...and by keyword
        direct.remove(entity=e1)
        control.remove(e1)
        assert_eq()

        # adds(1) through *args
        e3 = creator()
        direct.push(e3)
        control.append(e3)
        assert_eq()

        # removes("entity") with an extra trailing argument
        direct.yank(e3, "blah")
        control.remove(e3)
        assert_eq()

        e4, e5, e6, e7 = creator(), creator(), creator(), creator()
        direct.put(e4)
        direct.put(e5)
        control.append(e4)
        control.append(e5)

        # replaces(2) with positional args plus extra **kw
        dr1 = direct.replace("foo", e6, bar="baz")
        control.insert(0, e6)
        cr1 = control.pop()
        assert_eq()
        self.assert_(dr1 is cr1)

        # replaces(2) with all-keyword args
        dr2 = direct.replace(arg=1, entity=e7)
        control.insert(0, e7)
        cr2 = control.pop()
        assert_eq()
        self.assert_(dr2 is cr2)

        # removes_return(): the popped value is reported as removed
        dr3 = direct.pop("blah")
        cr3 = control.pop()
        assert_eq()
        self.assert_(dr3 is cr3)
1768
1769    def test_lifecycle(self):
1770        class Foo(object):
1771            pass
1772
1773        canary = Canary()
1774        creator = self.entity_maker
1775        instrumentation.register_class(Foo)
1776        d = attributes.register_attribute(
1777            Foo, "attr", uselist=True, useobject=True
1778        )
1779        canary.listen(d)
1780
1781        obj = Foo()
1782        col1 = obj.attr
1783
1784        e1 = creator()
1785        obj.attr.append(e1)
1786
1787        e2 = creator()
1788        bulk1 = [e2]
1789        # empty & sever col1 from obj
1790        obj.attr = bulk1
1791
1792        # as of [ticket:3913] the old collection
1793        # remains unchanged
1794        self.assert_(len(col1) == 1)
1795
1796        self.assert_(len(canary.data) == 1)
1797        self.assert_(obj.attr is not col1)
1798        self.assert_(obj.attr is not bulk1)
1799        self.assert_(obj.attr == bulk1)
1800
1801        e3 = creator()
1802        col1.append(e3)
1803        self.assert_(e3 not in canary.data)
1804        self.assert_(collections.collection_adapter(col1) is None)
1805
1806        obj.attr[0] = e3
1807        self.assert_(e3 in canary.data)
1808
1809
class DictHelpersTest(OrderedDictFixture, fixtures.MappedTest):
    """Persistence round-trips for the dictionary collection helpers.

    Most tests build a collection class (``mapped_collection``,
    ``attribute_mapped_collection``, ``column_mapped_collection`` or a
    ``MappedCollection`` / ordered-dict mixin), map ``Parent.children``
    with it, and check that keyed adds, replacements and removals survive
    a flush and reload.
    """

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "parents",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("label", String(128)),
        )
        Table(
            "children",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column(
                "parent_id", Integer, ForeignKey("parents.id"), nullable=False
            ),
            Column("a", String(128)),
            Column("b", String(128)),
            Column("c", String(128)),
        )

    @classmethod
    def setup_classes(cls):
        class Parent(cls.Basic):
            def __init__(self, label=None):
                self.label = label

        class Child(cls.Basic):
            def __init__(self, a=None, b=None, c=None):
                self.a = a
                self.b = b
                self.c = c

    def _map_parent_child(self, collection_class):
        """Map ``Child`` plainly and ``Parent.children`` using
        ``collection_class`` with ``all, delete-orphan`` cascade.

        :return: ``(Parent, Child)``
        """
        parents, children, Parent, Child = (
            self.tables.parents,
            self.tables.children,
            self.classes.Parent,
            self.classes.Child,
        )

        self.mapper_registry.map_imperatively(Child, children)
        self.mapper_registry.map_imperatively(
            Parent,
            parents,
            properties={
                "children": relationship(
                    Child,
                    collection_class=collection_class,
                    cascade="all, delete-orphan",
                )
            },
        )
        return Parent, Child

    @staticmethod
    def _adapter_len(collection):
        """Return the number of members the collection adapter reports."""
        return len(list(collections.collection_adapter(collection)))

    def _test_scalar_mapped(self, collection_class):
        """Persist and mutate a dict collection keyed on ``Child.a``.

        The steps checked, each surviving a flush and reload:

        - appending a new child under an existing key through the adapter
          replaces the old child (the key then maps to a row with a
          different id);
        - ``remove_with_event`` removes exactly one child;
        - ``del`` removes exactly one child.
        """
        Parent, Child = self._map_parent_child(collection_class)

        p = Parent()
        p.children["foo"] = Child("foo", "value")
        p.children["bar"] = Child("bar", "value")
        session = fixture_session()
        session.add(p)
        session.flush()
        pid = p.id
        session.expunge_all()

        p = session.get(Parent, pid)

        eq_(set(p.children.keys()), {"foo", "bar"})
        cid = p.children["foo"].id

        collections.collection_adapter(p.children).append_with_event(
            Child("foo", "newvalue")
        )

        session.flush()
        session.expunge_all()

        p = session.get(Parent, pid)

        eq_(set(p.children.keys()), {"foo", "bar"})
        ne_(p.children["foo"].id, cid)

        eq_(self._adapter_len(p.children), 2)
        session.flush()
        session.expunge_all()

        p = session.get(Parent, pid)
        eq_(self._adapter_len(p.children), 2)

        collections.collection_adapter(p.children).remove_with_event(
            p.children["foo"]
        )

        eq_(self._adapter_len(p.children), 1)
        session.flush()
        session.expunge_all()

        p = session.get(Parent, pid)
        eq_(self._adapter_len(p.children), 1)

        del p.children["bar"]
        eq_(self._adapter_len(p.children), 0)
        session.flush()
        session.expunge_all()

        p = session.get(Parent, pid)
        eq_(self._adapter_len(p.children), 0)

    def _test_composite_mapped(self, collection_class):
        """Persist a dict collection keyed on the ``(a, b)`` tuple.

        Appending a new child under an existing composite key through the
        adapter must replace the old child after a flush and reload.
        """
        Parent, Child = self._map_parent_child(collection_class)

        p = Parent()
        p.children[("foo", "1")] = Child("foo", "1", "value 1")
        p.children[("foo", "2")] = Child("foo", "2", "value 2")

        session = fixture_session()
        session.add(p)
        session.flush()
        pid = p.id
        session.expunge_all()

        p = session.get(Parent, pid)

        eq_(set(p.children.keys()), {("foo", "1"), ("foo", "2")})
        cid = p.children[("foo", "1")].id

        collections.collection_adapter(p.children).append_with_event(
            Child("foo", "1", "newvalue")
        )

        session.flush()
        session.expunge_all()

        p = session.get(Parent, pid)

        eq_(set(p.children.keys()), {("foo", "1"), ("foo", "2")})
        ne_(p.children[("foo", "1")].id, cid)

        eq_(self._adapter_len(p.children), 2)

    def test_mapped_collection(self):
        collection_class = collections.mapped_collection(lambda c: c.a)
        self._test_scalar_mapped(collection_class)

    def test_mapped_collection2(self):
        collection_class = collections.mapped_collection(lambda c: (c.a, c.b))
        self._test_composite_mapped(collection_class)

    def test_attr_mapped_collection(self):
        collection_class = collections.attribute_mapped_collection("a")
        self._test_scalar_mapped(collection_class)

    def test_declarative_column_mapped(self):
        """test that uncompiled attribute usage works with
        column_mapped_collection"""

        BaseObject = declarative_base()

        class Foo(BaseObject):
            __tablename__ = "foo"
            id = Column(Integer(), primary_key=True)
            bar_id = Column(Integer, ForeignKey("bar.id"))

        for spec, obj, expected in (
            (Foo.id, Foo(id=3), 3),
            ((Foo.id, Foo.bar_id), Foo(id=3, bar_id=12), (3, 12)),
        ):
            eq_(
                collections.column_mapped_collection(spec)().keyfunc(obj),
                expected,
            )

    def test_column_mapped_assertions(self):
        # column_mapped_collection rejects a plain string and a text()
        # construct as the mapping spec
        assert_raises_message(
            sa_exc.ArgumentError,
            "Column expression expected "
            "for argument 'mapping_spec'; got 'a'.",
            collections.column_mapped_collection,
            "a",
        )
        assert_raises_message(
            sa_exc.ArgumentError,
            "Column expression expected "
            "for argument 'mapping_spec'; got .*TextClause.",
            collections.column_mapped_collection,
            text("a"),
        )

    def test_column_mapped_collection(self):
        children = self.tables.children

        collection_class = collections.column_mapped_collection(children.c.a)
        self._test_scalar_mapped(collection_class)

    def test_column_mapped_collection2(self):
        children = self.tables.children

        collection_class = collections.column_mapped_collection(
            (children.c.a, children.c.b)
        )
        self._test_composite_mapped(collection_class)

    def test_mixin(self, ordered_dict_mro):
        class Ordered(ordered_dict_mro):
            def __init__(self):
                collections.MappedCollection.__init__(self, lambda v: v.a)
                util.OrderedDict.__init__(self)

        collection_class = Ordered
        self._test_scalar_mapped(collection_class)

    def test_mixin2(self, ordered_dict_mro):
        class Ordered2(ordered_dict_mro):
            def __init__(self, keyfunc):
                collections.MappedCollection.__init__(self, keyfunc)
                util.OrderedDict.__init__(self)

        def collection_class():
            return Ordered2(lambda v: (v.a, v.b))

        self._test_composite_mapped(collection_class)
2069
2070
class ColumnMappedWSerialize(fixtures.MappedTest):
    """Check column_mapped_collection key functions built from columns of
    an inheriting (multi-table) mapping and of a mapped SELECT alias, and
    check that those key functions still work after pickle round trips."""

    # NOTE(review): presumably tells the fixture to skip table creation and
    # row deletion; nothing in these tests is flushed to a database.
    run_create_tables = run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "foo",
            metadata,
            Column("id", Integer(), primary_key=True),
            Column("b", String(128)),
        )
        # defined in schema "x"; the tests look it up as tables["x.bar"]
        Table(
            "bar",
            metadata,
            Column("id", Integer(), primary_key=True),
            Column("foo_id", Integer, ForeignKey("foo.id")),
            Column("bat_id", Integer),
            schema="x",
        )

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

        # subclass of Foo; mapped below with inherits=Foo
        class Bar(Foo):
            pass

    def test_indirect_table_column_mapped(self):
        Foo, Bar = self.classes.Foo, self.classes.Bar
        foo_table = self.tables.foo
        bar_table = self.tables["x.bar"]

        # both "id" columns are exposed under attribute names that differ
        # from the column names
        self.mapper_registry.map_imperatively(
            Foo, foo_table, properties={"foo_id": foo_table.c.id}
        )
        self.mapper_registry.map_imperatively(
            Bar,
            bar_table,
            inherits=Foo,
            properties={"bar_id": bar_table.c.id},
        )

        instance = Bar(foo_id=1, bar_id=2, bat_id=3)
        self._run_test(
            [
                # attribute declared on the base class
                (Foo.foo_id, instance, 1),
                # tuple key from two subclass attributes
                ((Bar.bar_id, Bar.bat_id), instance, (2, 3)),
                # same base attribute, reached through the subclass
                (Bar.foo_id, instance, 1),
                # raw Table column whose attribute name is "bar_id"
                (bar_table.c.id, instance, 2),
            ]
        )

    def test_selectable_column_mapped(self):
        from sqlalchemy import select

        Foo = self.classes.Foo
        aliased_select = select(self.tables.foo).alias()
        self.mapper_registry.map_imperatively(Foo, aliased_select)
        self._run_test(
            [
                (Foo.b, Foo(b=5), 5),
                (aliased_select.c.b, Foo(b=5), 5),
            ]
        )

    def _run_test(self, specs):
        """For each ``(spec, obj, expected)``, build a collection keyed by
        ``spec`` and assert its keyfunc maps ``obj`` to ``expected``, both
        directly and after one and two pickle round trips with every
        available pickler."""
        from sqlalchemy.testing.util import picklers

        for spec, obj, expected in specs:
            original = collections.column_mapped_collection(spec)()
            eq_(original.keyfunc(obj), expected)
            # exercises __reduce__: pickle the collection, then pickle the
            # already-unpickled copy once more
            for loads, dumps in picklers():
                restored = original
                for _ in range(2):
                    restored = loads(dumps(restored))
                    eq_(restored.keyfunc(obj), expected)
2144
2145
class CustomCollectionsTest(fixtures.MappedTest):
    """test the integration of collections with mapped classes."""

    @classmethod
    def define_tables(cls, metadata):
        # parent table
        Table(
            "sometable",
            metadata,
            Column(
                "col1",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("data", String(30)),
        )
        # child table; scol1 is the foreign key back to sometable.col1
        Table(
            "someothertable",
            metadata,
            Column(
                "col1",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("scol1", Integer, ForeignKey("sometable.col1")),
            Column("data", String(20)),
        )

    def test_basic(self):
        """test that a new instance's collection is an instance of the
        list subclass given as collection_class."""
        someothertable, sometable = (
            self.tables.someothertable,
            self.tables.sometable,
        )

        class MyList(list):
            pass

        class Foo(object):
            pass

        class Bar(object):
            pass

        self.mapper_registry.map_imperatively(
            Foo,
            sometable,
            properties={"bars": relationship(Bar, collection_class=MyList)},
        )
        self.mapper_registry.map_imperatively(Bar, someothertable)
        f = Foo()
        assert isinstance(f.bars, MyList)

    def test_lazyload(self):
        """test that a 'set' can be used as a collection and can lazyload."""

        someothertable, sometable = (
            self.tables.someothertable,
            self.tables.sometable,
        )

        class Foo(object):
            pass

        class Bar(object):
            pass

        self.mapper_registry.map_imperatively(
            Foo,
            sometable,
            properties={"bars": relationship(Bar, collection_class=set)},
        )
        self.mapper_registry.map_imperatively(Bar, someothertable)
        f = Foo()
        f.bars.add(Bar())
        f.bars.add(Bar())
        sess = fixture_session()
        sess.add(f)
        sess.flush()
        # drop the identity map so the reloaded Foo reads bars from the DB
        sess.expunge_all()
        f = sess.get(Foo, f.col1)
        assert len(list(f.bars)) == 2
        f.bars.clear()

    def test_dict(self):
        """test that a 'dict' can be used as a collection and can lazyload."""

        someothertable, sometable = (
            self.tables.someothertable,
            self.tables.sometable,
        )

        class Foo(object):
            pass

        class Bar(object):
            pass

        class AppenderDict(dict):
            # dict keyed by id(item); the decorators designate the methods
            # the ORM uses to add and remove members
            @collection.appender
            def set(self, item):
                self[id(item)] = item

            @collection.remover
            def remove(self, item):
                if id(item) in self:
                    del self[id(item)]

        self.mapper_registry.map_imperatively(
            Foo,
            sometable,
            properties={
                "bars": relationship(Bar, collection_class=AppenderDict)
            },
        )
        self.mapper_registry.map_imperatively(Bar, someothertable)
        f = Foo()
        f.bars.set(Bar())
        f.bars.set(Bar())
        sess = fixture_session()
        sess.add(f)
        sess.flush()
        sess.expunge_all()
        f = sess.get(Foo, f.col1)
        assert len(list(f.bars)) == 2
        f.bars.clear()

    def test_dict_wrapper(self):
        """test that the supplied 'dict' wrapper can be used as a
        collection and can lazyload."""

        someothertable, sometable = (
            self.tables.someothertable,
            self.tables.sometable,
        )

        class Foo(object):
            pass

        class Bar(object):
            def __init__(self, data):
                self.data = data

        self.mapper_registry.map_imperatively(
            Foo,
            sometable,
            properties={
                "bars": relationship(
                    Bar,
                    collection_class=collections.column_mapped_collection(
                        someothertable.c.data
                    ),
                )
            },
        )
        self.mapper_registry.map_imperatively(Bar, someothertable)

        f = Foo()
        col = collections.collection_adapter(f.bars)
        col.append_with_event(Bar("a"))
        col.append_with_event(Bar("b"))
        sess = fixture_session()
        sess.add(f)
        sess.flush()
        sess.expunge_all()
        f = sess.get(Foo, f.col1)
        assert len(list(f.bars)) == 2

        # keep the loaded objects alive so their id() values cannot be
        # reused by objects created later in this test
        strongref = list(f.bars.values())
        existing = {id(b) for b in strongref}

        # collection is keyed on Bar.data, so these replace the "b" and "a"
        # entries rather than adding new ones
        col = collections.collection_adapter(f.bars)
        col.append_with_event(Bar("b"))
        f.bars["a"] = Bar("a")
        sess.flush()
        sess.expunge_all()
        f = sess.get(Foo, f.col1)
        assert len(list(f.bars)) == 2

        replaced = {id(b) for b in f.bars.values()}
        ne_(existing, replaced)

    def test_list(self):
        self._test_list(list)

    def test_list_no_setslice(self):
        """run the list checks against a list-like class that is not a
        list subclass and defines no slice-assignment method of its own."""

        class ListLike(object):
            def __init__(self):
                self.data = []

            def append(self, item):
                self.data.append(item)

            def remove(self, item):
                self.data.remove(item)

            def insert(self, index, item):
                self.data.insert(index, item)

            def pop(self, index=-1):
                return self.data.pop(index)

            def extend(self):
                # guard: this implementation must never be used directly
                assert False

            def __len__(self):
                return len(self.data)

            def __setitem__(self, key, value):
                self.data[key] = value

            def __getitem__(self, key):
                return self.data[key]

            def __delitem__(self, key):
                del self.data[key]

            def __iter__(self):
                return iter(self.data)

            # defining __eq__ would otherwise set __hash__ to None
            __hash__ = object.__hash__

            def __eq__(self, other):
                # also serves the reflected ``list == ListLike`` comparison
                return self.data == other

            def __repr__(self):
                return "ListLike(%s)" % repr(self.data)

        self._test_list(ListLike)

    def _test_list(self, listcls):
        """Apply the same sequence of list operations to a plain list
        (``control``) and to a mapped collection of type ``listcls``,
        asserting after every step that the two hold the same items."""
        someothertable, sometable = (
            self.tables.someothertable,
            self.tables.sometable,
        )

        class Parent(object):
            pass

        class Child(object):
            pass

        self.mapper_registry.map_imperatively(
            Parent,
            sometable,
            properties={
                "children": relationship(Child, collection_class=listcls)
            },
        )
        self.mapper_registry.map_imperatively(Child, someothertable)

        control = []
        p = Parent()

        o = Child()
        control.append(o)
        p.children.append(o)
        assert control == p.children
        assert control == list(p.children)

        o = [Child(), Child(), Child(), Child()]
        control.extend(o)
        p.children.extend(o)
        assert control == p.children
        assert control == list(p.children)

        assert control[0] == p.children[0]
        assert control[-1] == p.children[-1]
        assert control[1:3] == p.children[1:3]

        del control[1]
        del p.children[1]
        assert control == p.children
        assert control == list(p.children)

        # slice assignment: shrinking replacement
        o = [Child()]
        control[1:3] = o

        p.children[1:3] = o
        assert control == p.children
        assert control == list(p.children)

        # slice assignment: growing replacement
        o = [Child(), Child(), Child(), Child()]
        control[1:3] = o
        p.children[1:3] = o
        assert control == p.children
        assert control == list(p.children)

        # slice assignment: empty negative slice (pure insertion)
        o = [Child(), Child(), Child(), Child()]
        control[-1:-2] = o
        p.children[-1:-2] = o
        assert control == p.children
        assert control == list(p.children)

        # slice assignment: open-ended tail
        o = [Child(), Child(), Child(), Child()]
        control[4:] = o
        p.children[4:] = o
        assert control == p.children
        assert control == list(p.children)

        o = Child()
        control.insert(0, o)
        p.children.insert(0, o)
        assert control == p.children
        assert control == list(p.children)

        o = Child()
        control.insert(3, o)
        p.children.insert(3, o)
        assert control == p.children
        assert control == list(p.children)

        # index past the end appends
        o = Child()
        control.insert(999, o)
        p.children.insert(999, o)
        assert control == p.children
        assert control == list(p.children)

        del control[0:1]
        del p.children[0:1]
        assert control == p.children
        assert control == list(p.children)

        # empty slice deletion is a no-op
        del control[1:1]
        del p.children[1:1]
        assert control == p.children
        assert control == list(p.children)

        del control[1:3]
        del p.children[1:3]
        assert control == p.children
        assert control == list(p.children)

        del control[7:]
        del p.children[7:]
        assert control == p.children
        assert control == list(p.children)

        assert control.pop() == p.children.pop()
        assert control == p.children
        assert control == list(p.children)

        assert control.pop(0) == p.children.pop(0)
        assert control == p.children
        assert control == list(p.children)

        assert control.pop(2) == p.children.pop(2)
        assert control == p.children
        assert control == list(p.children)

        o = Child()
        control.insert(2, o)
        p.children.insert(2, o)
        assert control == p.children
        assert control == list(p.children)

        control.remove(o)
        p.children.remove(o)
        assert control == p.children
        assert control == list(p.children)

    def test_custom(self):
        """test a non-builtin collection class whose appender, remover and
        iterator are designated with the collection decorators."""
        someothertable, sometable = (
            self.tables.someothertable,
            self.tables.sometable,
        )

        class Parent(object):
            pass

        class Child(object):
            pass

        class MyCollection(object):
            def __init__(self):
                self.data = []

            @collection.appender
            def append(self, value):
                self.data.append(value)

            @collection.remover
            def remove(self, value):
                self.data.remove(value)

            @collection.iterator
            def __iter__(self):
                return iter(self.data)

        self.mapper_registry.map_imperatively(
            Parent,
            sometable,
            properties={
                "children": relationship(Child, collection_class=MyCollection)
            },
        )
        self.mapper_registry.map_imperatively(Child, someothertable)

        control = []
        p1 = Parent()

        o = Child()
        control.append(o)
        p1.children.append(o)
        assert control == list(p1.children)

        o = Child()
        control.append(o)
        p1.children.append(o)
        assert control == list(p1.children)

        o = Child()
        control.append(o)
        p1.children.append(o)
        assert control == list(p1.children)

        sess = fixture_session()
        sess.add(p1)
        sess.flush()
        sess.expunge_all()

        p2 = sess.get(Parent, p1.col1)
        o = list(p2.children)
        assert len(o) == 3
2570
2571
class InstrumentationTest(fixtures.ORMTest):
    """Tests of collection class instrumentation and adapter ownership."""

    def test_uncooperative_descriptor_in_sweep(self):
        """_instrument_class must tolerate a class attribute whose
        descriptor raises AttributeError on access."""

        class DoNotTouch(object):
            def __get__(self, obj, owner):
                raise AttributeError

        class Touchy(list):
            no_touch = DoNotTouch()

        # preconditions: the name is in the class namespace and in dir(),
        # yet reading it fails
        is_true("no_touch" in Touchy.__dict__)
        is_false(hasattr(Touchy, "no_touch"))
        is_true("no_touch" in dir(Touchy))

        collections._instrument_class(Touchy)

    def test_referenced_by_owner(self):
        """An adapter stops being referenced by its owner once the owner's
        attribute is assigned a new collection."""

        class Foo(object):
            pass

        instrumentation.register_class(Foo)
        attributes.register_attribute(
            Foo, "attr", uselist=True, useobject=True
        )

        owner = Foo()
        owner.attr.append(3)

        adapter = collections.collection_adapter(owner.attr)
        is_true(adapter._referenced_by_owner)

        owner.attr = []
        is_false(adapter._referenced_by_owner)
2604