1import contextlib
2from operator import and_
3
4from sqlalchemy import event
5from sqlalchemy import exc as sa_exc
6from sqlalchemy import ForeignKey
7from sqlalchemy import Integer
8from sqlalchemy import String
9from sqlalchemy import text
10from sqlalchemy import util
11from sqlalchemy.orm import attributes
12from sqlalchemy.orm import create_session
13from sqlalchemy.orm import instrumentation
14from sqlalchemy.orm import mapper
15from sqlalchemy.orm import relationship
16import sqlalchemy.orm.collections as collections
17from sqlalchemy.orm.collections import collection
18from sqlalchemy.testing import assert_raises
19from sqlalchemy.testing import assert_raises_message
20from sqlalchemy.testing import eq_
21from sqlalchemy.testing import fixtures
22from sqlalchemy.testing import ne_
23from sqlalchemy.testing.schema import Column
24from sqlalchemy.testing.schema import Table
25
26
class Canary(object):
    """Collection event listener used to cross-check instrumentation.

    ``data`` mirrors what the collection should contain according to the
    events received.  ``added`` and ``removed`` accumulate every value ever
    appended or removed, so a value reported twice trips an assertion;
    that duplicate check can be suspended with :meth:`defer_dupe_check`.
    """

    def __init__(self):
        self.data, self.added, self.removed = set(), set(), set()
        self.dupe_check = True

    @contextlib.contextmanager
    def defer_dupe_check(self):
        """Disable the duplicate-event assertions inside a ``with`` block.

        On exit (normal or via exception) the check is switched back on.
        It is always reset to True, not restored to its previous value.
        """
        self.dupe_check = False
        try:
            yield
        finally:
            self.dupe_check = True

    def listen(self, attr):
        """Attach the append / remove / set handlers to *attr*."""
        handlers = (
            ("append", self.append),
            ("remove", self.remove),
            ("set", self.set),
        )
        for identifier, fn in handlers:
            event.listen(attr, identifier, fn)

    def _track(self, seen, value):
        # shared duplicate guard: while checking is on, a value may be
        # recorded in ``seen`` only once
        if self.dupe_check:
            assert value not in seen
            seen.add(value)

    def append(self, obj, value, initiator):
        """Record an append event and hand *value* back unchanged."""
        self._track(self.added, value)
        self.data.add(value)
        return value

    def remove(self, obj, value, initiator):
        """Record a remove event; *value* must currently be in ``data``."""
        self._track(self.removed, value)
        self.data.remove(value)

    def set(self, obj, value, oldvalue, initiator):
        """Record a scalar set as a removal of *oldvalue* plus an append.

        Returns the (possibly substituted) new value.
        """
        if isinstance(value, str):
            # string values are swapped for a freshly created entity
            value = CollectionsTest.entity_maker()

        if oldvalue is not None:
            self.remove(obj, oldvalue, None)
        self.append(obj, value, None)
        return value
68
69
70class CollectionsTest(fixtures.ORMTest):
71    class Entity(object):
72        def __init__(self, a=None, b=None, c=None):
73            self.a = a
74            self.b = b
75            self.c = c
76
77        def __repr__(self):
78            return str((id(self), self.a, self.b, self.c))
79
80    @classmethod
81    def setup_class(cls):
82        instrumentation.register_class(cls.Entity)
83
84    @classmethod
85    def teardown_class(cls):
86        instrumentation.unregister_class(cls.Entity)
87        super(CollectionsTest, cls).teardown_class()
88
89    _entity_id = 1
90
91    @classmethod
92    def entity_maker(cls):
93        cls._entity_id += 1
94        return cls.Entity(cls._entity_id)
95
96    @classmethod
97    def dictable_entity(cls, a=None, b=None, c=None):
98        id_ = cls._entity_id = cls._entity_id + 1
99        return cls.Entity(a or str(id_), b or "value %s" % id, c)
100
101    def _test_adapter(self, typecallable, creator=None, to_set=None):
102        if creator is None:
103            creator = self.entity_maker
104
105        class Foo(object):
106            pass
107
108        canary = Canary()
109        instrumentation.register_class(Foo)
110        d = attributes.register_attribute(
111            Foo,
112            "attr",
113            uselist=True,
114            typecallable=typecallable,
115            useobject=True,
116        )
117        canary.listen(d)
118
119        obj = Foo()
120        adapter = collections.collection_adapter(obj.attr)
121        direct = obj.attr
122        if to_set is None:
123
124            def to_set(col):
125                return set(col)
126
127        def assert_eq():
128            self.assert_(to_set(direct) == canary.data)
129            self.assert_(set(adapter) == canary.data)
130
131        def assert_ne():
132            self.assert_(to_set(direct) != canary.data)
133
134        e1, e2 = creator(), creator()
135
136        adapter.append_with_event(e1)
137        assert_eq()
138
139        adapter.append_without_event(e2)
140        assert_ne()
141        canary.data.add(e2)
142        assert_eq()
143
144        adapter.remove_without_event(e2)
145        assert_ne()
146        canary.data.remove(e2)
147        assert_eq()
148
149        adapter.remove_with_event(e1)
150        assert_eq()
151
    def _test_list(self, typecallable, creator=None):
        """Drive list-style mutations on an instrumented collection.

        Each mutation is applied both to the instrumented collection
        (``direct``) and to a plain ``list`` (``control``).  After each step,
        ``assert_eq`` checks two things: the collection equals the control,
        and the canary's event-derived membership matches the collection.
        Operations are exercised only when ``typecallable`` provides them
        (checked with ``hasattr``), so duck-typed list classes implementing
        a subset of the list API can reuse this driver.  The steps are
        order-dependent; each one starts from the state the previous left.

        :param typecallable: collection class/factory registered as the
          attribute's ``typecallable``.
        :param creator: zero-argument callable producing fresh members;
          defaults to :meth:`entity_maker`.
        """
        if creator is None:
            creator = self.entity_maker

        class Foo(object):
            pass

        canary = Canary()
        instrumentation.register_class(Foo)
        d = attributes.register_attribute(
            Foo,
            "attr",
            uselist=True,
            typecallable=typecallable,
            useobject=True,
        )
        canary.listen(d)

        obj = Foo()
        adapter = collections.collection_adapter(obj.attr)
        direct = obj.attr
        control = list()

        def assert_eq():
            # event-derived membership must match the collection, and the
            # collection must compare equal to the plain-list control
            eq_(set(direct), canary.data)
            eq_(set(adapter), canary.data)
            eq_(direct, control)

        # assume append() is available for list tests
        e = creator()
        direct.append(e)
        control.append(e)
        assert_eq()

        if hasattr(direct, "pop"):
            direct.pop()
            control.pop()
            assert_eq()

        if hasattr(direct, "__setitem__"):
            e = creator()
            direct.append(e)
            control.append(e)

            # replace a single element by index
            e = creator()
            direct[0] = e
            control[0] = e
            assert_eq()

            # the slice assignments below may change the length, so they are
            # only attempted when the type also has __delitem__, insert and
            # __len__
            if util.reduce(
                and_,
                [
                    hasattr(direct, a)
                    for a in ("__delitem__", "insert", "__len__")
                ],
                True,
            ):
                values = [creator(), creator(), creator(), creator()]
                direct[slice(0, 1)] = values
                control[slice(0, 1)] = values
                assert_eq()

                values = [creator(), creator()]
                direct[slice(0, -1, 2)] = values
                control[slice(0, -1, 2)] = values
                assert_eq()

                values = [creator()]
                direct[slice(0, -1)] = values
                control[slice(0, -1)] = values
                assert_eq()

                values = [creator(), creator(), creator()]
                control[:] = values
                direct[:] = values

                # extended slice over 3 items selects positions 0 and 2;
                # giving it one value must raise ValueError, like a list
                def invalid():
                    direct[slice(0, 6, 2)] = [creator()]

                assert_raises(ValueError, invalid)

        if hasattr(direct, "__delitem__"):
            e = creator()
            direct.append(e)
            control.append(e)
            del direct[-1]
            del control[-1]
            assert_eq()

            # __getslice__ is a Python 2 protocol; builtin lists on Python 3
            # lack it, so this block presumably only runs under Python 2
            if hasattr(direct, "__getslice__"):
                for e in [creator(), creator(), creator(), creator()]:
                    direct.append(e)
                    control.append(e)

                del direct[:-3]
                del control[:-3]
                assert_eq()

                del direct[0:1]
                del control[0:1]
                assert_eq()

                del direct[::2]
                del control[::2]
                assert_eq()

        if hasattr(direct, "remove"):
            e = creator()
            direct.append(e)
            control.append(e)

            direct.remove(e)
            control.remove(e)
            assert_eq()

        # slice assignment (__setslice__ is the Python 2 spelling)
        if hasattr(direct, "__setitem__") or hasattr(direct, "__setslice__"):

            values = [creator(), creator()]
            direct[:] = values
            control[:] = values
            assert_eq()

            # test slice assignment where we slice assign to self,
            # currently a no-op, issue #4990
            # note that in py2k, the bug does not exist but it recreates
            # the collection which breaks our fixtures here
            with canary.defer_dupe_check():
                direct[:] = direct
                control[:] = control
            assert_eq()

            # we dont handle assignment of self to slices, as this
            # implies duplicate entries.  behavior here is not well defined
            # and perhaps should emit a warning
            # direct[0:1] = list(direct)
            # control[0:1] = list(control)
            # assert_eq()

            # test slice assignment where
            # slice size goes over the number of items
            values = [creator(), creator()]
            direct[1:3] = values
            control[1:3] = values
            assert_eq()

            values = [creator(), creator()]
            direct[0:1] = values
            control[0:1] = values
            assert_eq()

            values = [creator()]
            direct[0:] = values
            control[0:] = values
            assert_eq()

            values = [creator()]
            direct[:1] = values
            control[:1] = values
            assert_eq()

            values = [creator()]
            direct[-1::2] = values
            control[-1::2] = values
            assert_eq()

            # note: this repeats ONE entity len(direct[1::2]) times; with
            # more than one position the canary's duplicate-append check
            # would presumably fire, so this relies on that slice being at
            # most one element long at this point
            values = [creator()] * len(direct[1::2])
            direct[1::2] = values
            control[1::2] = values
            assert_eq()

            # start lies after stop: an empty slice, so this inserts
            values = [creator(), creator()]
            direct[-1:-3] = values
            control[-1:-3] = values
            assert_eq()

            values = [creator(), creator()]
            direct[-2:-1] = values
            control[-2:-1] = values
            assert_eq()

            # empty slice at the front: pure insertion
            values = [creator()]
            direct[0:0] = values
            control[0:0] = values
            assert_eq()

        # slice deletion (__delslice__ is the Python 2 spelling)
        if hasattr(direct, "__delitem__") or hasattr(direct, "__delslice__"):
            for i in range(1, 4):
                e = creator()
                direct.append(e)
                control.append(e)

            del direct[-1:]
            del control[-1:]
            assert_eq()

            del direct[1:2]
            del control[1:2]
            assert_eq()

            del direct[:]
            del control[:]
            assert_eq()

        if hasattr(direct, "clear"):
            for i in range(1, 4):
                e = creator()
                direct.append(e)
                control.append(e)

            direct.clear()
            control.clear()
            assert_eq()

        if hasattr(direct, "extend"):
            values = [creator(), creator(), creator()]

            direct.extend(values)
            control.extend(values)
            assert_eq()

        if hasattr(direct, "__iadd__"):
            values = [creator(), creator(), creator()]

            direct += values
            control += values
            assert_eq()

            direct += []
            control += []
            assert_eq()

            # through the attribute: ``+=`` assigns the (same) collection
            # back to obj.attr
            values = [creator(), creator()]
            obj.attr += values
            control += values
            assert_eq()

        if hasattr(direct, "__imul__"):
            # duplicates every member; assert_eq's set-based membership
            # comparison is unaffected by the repetition
            direct *= 2
            control *= 2
            assert_eq()

            obj.attr *= 2
            control *= 2
            assert_eq()
396
397    def _test_list_bulk(self, typecallable, creator=None):
398        if creator is None:
399            creator = self.entity_maker
400
401        class Foo(object):
402            pass
403
404        canary = Canary()
405        instrumentation.register_class(Foo)
406        d = attributes.register_attribute(
407            Foo,
408            "attr",
409            uselist=True,
410            typecallable=typecallable,
411            useobject=True,
412        )
413        canary.listen(d)
414
415        obj = Foo()
416        direct = obj.attr
417
418        e1 = creator()
419        obj.attr.append(e1)
420
421        like_me = typecallable()
422        e2 = creator()
423        like_me.append(e2)
424
425        self.assert_(obj.attr is direct)
426        obj.attr = like_me
427        self.assert_(obj.attr is not direct)
428        self.assert_(obj.attr is not like_me)
429        self.assert_(set(obj.attr) == set([e2]))
430        self.assert_(e1 in canary.removed)
431        self.assert_(e2 in canary.added)
432
433        e3 = creator()
434        real_list = [e3]
435        obj.attr = real_list
436        self.assert_(obj.attr is not real_list)
437        self.assert_(set(obj.attr) == set([e3]))
438        self.assert_(e2 in canary.removed)
439        self.assert_(e3 in canary.added)
440
441        e4 = creator()
442        try:
443            obj.attr = set([e4])
444            self.assert_(False)
445        except TypeError:
446            self.assert_(e4 not in canary.data)
447            self.assert_(e3 in canary.data)
448
449        e5 = creator()
450        e6 = creator()
451        e7 = creator()
452        obj.attr = [e5, e6, e7]
453        self.assert_(e5 in canary.added)
454        self.assert_(e6 in canary.added)
455        self.assert_(e7 in canary.added)
456
457        obj.attr = [e6, e7]
458        self.assert_(e5 in canary.removed)
459        self.assert_(e6 in canary.added)
460        self.assert_(e7 in canary.added)
461        self.assert_(e6 not in canary.removed)
462        self.assert_(e7 not in canary.removed)
463
464    def test_list(self):
465        self._test_adapter(list)
466        self._test_list(list)
467        self._test_list_bulk(list)
468
    def test_list_setitem_with_slices(self):
        """Run the list drivers against a duck-typed list whose
        ``__setitem__`` / ``__delitem__`` are handed slice objects."""

        # this is a "list" that has no __setslice__
        # or __delslice__ methods.  The __setitem__
        # and __delitem__ must therefore accept
        # slice objects (i.e. as in py3k)
        class ListLike(object):
            def __init__(self):
                self.data = list()

            def append(self, item):
                self.data.append(item)

            def remove(self, item):
                self.data.remove(item)

            def insert(self, index, item):
                self.data.insert(index, item)

            def pop(self, index=-1):
                return self.data.pop(index)

            # tripwire: takes no iterable and always fails; presumably the
            # instrumentation supplies its own extend(), so this body is
            # never expected to run
            def extend(self):
                assert False

            def __len__(self):
                return len(self.data)

            def __setitem__(self, key, value):
                self.data[key] = value

            def __getitem__(self, key):
                return self.data[key]

            def __delitem__(self, key):
                del self.data[key]

            def __iter__(self):
                return iter(self.data)

            # keep identity hashing; defining __eq__ alone would make the
            # class unhashable on Python 3
            __hash__ = object.__hash__

            def __eq__(self, other):
                return self.data == other

            def __repr__(self):
                return "ListLike(%s)" % repr(self.data)

        self._test_adapter(ListLike)
        self._test_list(ListLike)
        self._test_list_bulk(ListLike)
520
521    def test_list_subclass(self):
522        class MyList(list):
523            pass
524
525        self._test_adapter(MyList)
526        self._test_list(MyList)
527        self._test_list_bulk(MyList)
528        self.assert_(getattr(MyList, "_sa_instrumented") == id(MyList))
529
    def test_list_duck(self):
        """Run the list drivers against a duck-typed list with no indexing.

        ``ListLike`` defines no ``__setitem__`` / ``__delitem__`` /
        ``__len__``, so the index and slice sections of ``_test_list`` are
        skipped by its ``hasattr`` checks.  The class must still end up
        instrumented (tagged with its own id).
        """
        class ListLike(object):
            def __init__(self):
                self.data = list()

            def append(self, item):
                self.data.append(item)

            def remove(self, item):
                self.data.remove(item)

            def insert(self, index, item):
                self.data.insert(index, item)

            def pop(self, index=-1):
                return self.data.pop(index)

            # tripwire: takes no iterable and always fails; presumably the
            # instrumentation supplies its own extend(), so this body is
            # never expected to run
            def extend(self):
                assert False

            def __iter__(self):
                return iter(self.data)

            # keep identity hashing; defining __eq__ alone would make the
            # class unhashable on Python 3
            __hash__ = object.__hash__

            def __eq__(self, other):
                return self.data == other

            def __repr__(self):
                return "ListLike(%s)" % repr(self.data)

        self._test_adapter(ListLike)
        self._test_list(ListLike)
        self._test_list_bulk(ListLike)
        self.assert_(getattr(ListLike, "_sa_instrumented") == id(ListLike))
565
    def test_list_emulates(self):
        """Run the list drivers against a class declaring ``__emulates__``.

        ``ListIsh`` has the same method set as the duck-typed list test, but
        additionally sets ``__emulates__ = list``, presumably telling the
        collection instrumentation to treat it as list-like.  The class must
        end up instrumented (tagged with its own id).
        """
        class ListIsh(object):
            __emulates__ = list

            def __init__(self):
                self.data = list()

            def append(self, item):
                self.data.append(item)

            def remove(self, item):
                self.data.remove(item)

            def insert(self, index, item):
                self.data.insert(index, item)

            def pop(self, index=-1):
                return self.data.pop(index)

            # tripwire: takes no iterable and always fails; presumably the
            # instrumentation supplies its own extend(), so this body is
            # never expected to run
            def extend(self):
                assert False

            def __iter__(self):
                return iter(self.data)

            # keep identity hashing; defining __eq__ alone would make the
            # class unhashable on Python 3
            __hash__ = object.__hash__

            def __eq__(self, other):
                return self.data == other

            def __repr__(self):
                return "ListIsh(%s)" % repr(self.data)

        self._test_adapter(ListIsh)
        self._test_list(ListIsh)
        self._test_list_bulk(ListIsh)
        self.assert_(getattr(ListIsh, "_sa_instrumented") == id(ListIsh))
603
    def _test_set(self, typecallable, creator=None):
        """Drive set-style mutations on an instrumented collection.

        Each mutation is applied both to the instrumented collection
        (``direct``) and to a plain ``set`` (``control``).  ``assert_eq``
        then checks two things: the collection equals the control, and the
        canary's event-derived membership matches the collection.
        Operations are exercised only when ``typecallable`` provides them
        (``hasattr``).  The in-place operators are also checked to reject
        non-set operands with ``TypeError``.

        :param typecallable: collection class/factory registered as the
          attribute's ``typecallable``.
        :param creator: zero-argument member factory; defaults to
          :meth:`entity_maker`.
        """
        if creator is None:
            creator = self.entity_maker

        class Foo(object):
            pass

        canary = Canary()
        instrumentation.register_class(Foo)
        d = attributes.register_attribute(
            Foo,
            "attr",
            uselist=True,
            typecallable=typecallable,
            useobject=True,
        )
        canary.listen(d)

        obj = Foo()
        adapter = collections.collection_adapter(obj.attr)
        direct = obj.attr
        control = set()

        def assert_eq():
            eq_(set(direct), canary.data)
            eq_(set(adapter), canary.data)
            eq_(direct, control)

        # add each value to both sides, then verify
        def addall(*values):
            for item in values:
                direct.add(item)
                control.add(item)
            assert_eq()

        # empty both sides; the collection is emptied member by member
        # through remove()
        def zap():
            for item in list(direct):
                direct.remove(item)
            control.clear()

        addall(creator())

        # re-adding an existing member must not produce a second append
        # event (the canary's duplicate check would fail otherwise)
        e = creator()
        addall(e)
        addall(e)

        if hasattr(direct, "remove"):
            e = creator()
            addall(e)

            direct.remove(e)
            control.remove(e)
            assert_eq()

            # removing a non-member raises KeyError and fires no event
            e = creator()
            try:
                direct.remove(e)
            except KeyError:
                assert_eq()
                self.assert_(e not in canary.removed)
            else:
                self.assert_(False)

        if hasattr(direct, "discard"):
            e = creator()
            addall(e)

            direct.discard(e)
            control.discard(e)
            assert_eq()

            # discarding a non-member is silent and fires no event
            e = creator()
            direct.discard(e)
            self.assert_(e not in canary.removed)
            assert_eq()

        if hasattr(direct, "update"):
            zap()
            e = creator()
            addall(e)

            values = set([e, creator(), creator()])

            direct.update(values)
            control.update(values)
            assert_eq()

        if hasattr(direct, "__ior__"):
            zap()
            e = creator()
            addall(e)

            values = set([e, creator(), creator()])

            direct |= values
            control |= values
            assert_eq()

            # cover self-assignment short-circuit
            values = set([e, creator(), creator()])
            obj.attr |= values
            control |= values
            assert_eq()

            values = frozenset([e, creator()])
            obj.attr |= values
            control |= values
            assert_eq()

            # in-place operator with a non-set operand must raise TypeError
            try:
                direct |= [e, creator()]
                assert False
            except TypeError:
                assert True

        addall(creator(), creator())
        direct.clear()
        control.clear()
        assert_eq()

        # note: the clear test previously needs
        # to have executed in order for this to
        # pass in all cases; else there's the possibility
        # of non-deterministic behavior.
        addall(creator())
        direct.pop()
        control.pop()
        assert_eq()

        if hasattr(direct, "difference_update"):
            zap()
            # e is deliberately never added: it is used below as a
            # non-member in the subtracted set
            e = creator()
            addall(creator(), creator())
            values = set([creator()])

            direct.difference_update(values)
            control.difference_update(values)
            assert_eq()
            values.update(set([e, creator()]))
            direct.difference_update(values)
            control.difference_update(values)
            assert_eq()

        if hasattr(direct, "__isub__"):
            zap()
            # e is deliberately never added (non-member operand below)
            e = creator()
            addall(creator(), creator())
            values = set([creator()])

            direct -= values
            control -= values
            assert_eq()
            values.update(set([e, creator()]))
            direct -= values
            control -= values
            assert_eq()

            values = set([creator()])
            obj.attr -= values
            control -= values
            assert_eq()

            values = frozenset([creator()])
            obj.attr -= values
            control -= values
            assert_eq()

            # in-place operator with a non-set operand must raise TypeError
            try:
                direct -= [e, creator()]
                assert False
            except TypeError:
                assert True

        if hasattr(direct, "intersection_update"):
            zap()
            e = creator()
            addall(e, creator(), creator())
            values = set(control)

            # intersecting with its own contents changes nothing
            direct.intersection_update(values)
            control.intersection_update(values)
            assert_eq()

            values.update(set([e, creator()]))
            direct.intersection_update(values)
            control.intersection_update(values)
            assert_eq()

        if hasattr(direct, "__iand__"):
            zap()
            e = creator()
            addall(e, creator(), creator())
            values = set(control)

            direct &= values
            control &= values
            assert_eq()

            values.update(set([e, creator()]))
            direct &= values
            control &= values
            assert_eq()

            values.update(set([creator()]))
            obj.attr &= values
            control &= values
            assert_eq()

            # in-place operator with a non-set operand must raise TypeError
            try:
                direct &= [e, creator()]
                assert False
            except TypeError:
                assert True

        if hasattr(direct, "symmetric_difference_update"):
            zap()
            e = creator()
            addall(e, creator(), creator())

            # e drops out, the fresh entity comes in
            values = set([e, creator()])
            direct.symmetric_difference_update(values)
            control.symmetric_difference_update(values)
            assert_eq()

            e = creator()
            addall(e)
            values = set([e])
            direct.symmetric_difference_update(values)
            control.symmetric_difference_update(values)
            assert_eq()

            values = set()
            direct.symmetric_difference_update(values)
            control.symmetric_difference_update(values)
            assert_eq()

        if hasattr(direct, "__ixor__"):
            zap()
            e = creator()
            addall(e, creator(), creator())

            values = set([e, creator()])
            direct ^= values
            control ^= values
            assert_eq()

            e = creator()
            addall(e)
            values = set([e])
            direct ^= values
            control ^= values
            assert_eq()

            values = set()
            direct ^= values
            control ^= values
            assert_eq()

            values = set([creator()])
            obj.attr ^= values
            control ^= values
            assert_eq()

            # in-place operator with a non-set operand must raise TypeError
            try:
                direct ^= [e, creator()]
                assert False
            except TypeError:
                assert True
871
872    def _test_set_bulk(self, typecallable, creator=None):
873        if creator is None:
874            creator = self.entity_maker
875
876        class Foo(object):
877            pass
878
879        canary = Canary()
880        instrumentation.register_class(Foo)
881        d = attributes.register_attribute(
882            Foo,
883            "attr",
884            uselist=True,
885            typecallable=typecallable,
886            useobject=True,
887        )
888        canary.listen(d)
889
890        obj = Foo()
891        direct = obj.attr
892
893        e1 = creator()
894        obj.attr.add(e1)
895
896        like_me = typecallable()
897        e2 = creator()
898        like_me.add(e2)
899
900        self.assert_(obj.attr is direct)
901        obj.attr = like_me
902        self.assert_(obj.attr is not direct)
903        self.assert_(obj.attr is not like_me)
904        self.assert_(obj.attr == set([e2]))
905        self.assert_(e1 in canary.removed)
906        self.assert_(e2 in canary.added)
907
908        e3 = creator()
909        real_set = set([e3])
910        obj.attr = real_set
911        self.assert_(obj.attr is not real_set)
912        self.assert_(obj.attr == set([e3]))
913        self.assert_(e2 in canary.removed)
914        self.assert_(e3 in canary.added)
915
916        e4 = creator()
917        try:
918            obj.attr = [e4]
919            self.assert_(False)
920        except TypeError:
921            self.assert_(e4 not in canary.data)
922            self.assert_(e3 in canary.data)
923
924    def test_set(self):
925        self._test_adapter(set)
926        self._test_set(set)
927        self._test_set_bulk(set)
928
929    def test_set_subclass(self):
930        class MySet(set):
931            pass
932
933        self._test_adapter(MySet)
934        self._test_set(MySet)
935        self._test_set_bulk(MySet)
936        self.assert_(getattr(MySet, "_sa_instrumented") == id(MySet))
937
    def test_set_duck(self):
        """Run the set drivers against a duck-typed set wrapping a ``set``.

        ``SetLike`` has no in-place operators or ``*_update`` methods other
        than ``update``, so those sections of ``_test_set`` are skipped by
        its ``hasattr`` checks.  The class must still end up instrumented
        (tagged with its own id).
        """
        class SetLike(object):
            def __init__(self):
                self.data = set()

            def add(self, item):
                self.data.add(item)

            def remove(self, item):
                self.data.remove(item)

            def discard(self, item):
                self.data.discard(item)

            def clear(self):
                self.data.clear()

            def pop(self):
                return self.data.pop()

            def update(self, other):
                self.data.update(other)

            def __iter__(self):
                return iter(self.data)

            # keep identity hashing; defining __eq__ alone would make the
            # class unhashable on Python 3
            __hash__ = object.__hash__

            def __eq__(self, other):
                return self.data == other

        self._test_adapter(SetLike)
        self._test_set(SetLike)
        self._test_set_bulk(SetLike)
        self.assert_(getattr(SetLike, "_sa_instrumented") == id(SetLike))
973
    def test_set_emulates(self):
        """Run the set drivers against a class declaring ``__emulates__``.

        ``SetIsh`` mirrors the duck-typed set test but sets
        ``__emulates__ = set``, presumably telling the collection
        instrumentation to treat it as set-like.  The class must end up
        instrumented (tagged with its own id).
        """
        class SetIsh(object):
            __emulates__ = set

            def __init__(self):
                self.data = set()

            def add(self, item):
                self.data.add(item)

            def remove(self, item):
                self.data.remove(item)

            def discard(self, item):
                self.data.discard(item)

            def pop(self):
                return self.data.pop()

            def update(self, other):
                self.data.update(other)

            def __iter__(self):
                return iter(self.data)

            def clear(self):
                self.data.clear()

            # keep identity hashing; defining __eq__ alone would make the
            # class unhashable on Python 3
            __hash__ = object.__hash__

            def __eq__(self, other):
                return self.data == other

        self._test_adapter(SetIsh)
        self._test_set(SetIsh)
        self._test_set_bulk(SetIsh)
        self.assert_(getattr(SetIsh, "_sa_instrumented") == id(SetIsh))
1011
1012    def _test_dict(self, typecallable, creator=None):
1013        if creator is None:
1014            creator = self.dictable_entity
1015
1016        class Foo(object):
1017            pass
1018
1019        canary = Canary()
1020        instrumentation.register_class(Foo)
1021        d = attributes.register_attribute(
1022            Foo,
1023            "attr",
1024            uselist=True,
1025            typecallable=typecallable,
1026            useobject=True,
1027        )
1028        canary.listen(d)
1029
1030        obj = Foo()
1031        adapter = collections.collection_adapter(obj.attr)
1032        direct = obj.attr
1033        control = dict()
1034
1035        def assert_eq():
1036            self.assert_(set(direct.values()) == canary.data)
1037            self.assert_(set(adapter) == canary.data)
1038            self.assert_(direct == control)
1039
1040        def addall(*values):
1041            for item in values:
1042                direct.set(item)
1043                control[item.a] = item
1044            assert_eq()
1045
1046        def zap():
1047            for item in list(adapter):
1048                direct.remove(item)
1049            control.clear()
1050
1051        # assume an 'set' method is available for tests
1052        addall(creator())
1053
1054        if hasattr(direct, "__setitem__"):
1055            e = creator()
1056            direct[e.a] = e
1057            control[e.a] = e
1058            assert_eq()
1059
1060            e = creator(e.a, e.b)
1061            direct[e.a] = e
1062            control[e.a] = e
1063            assert_eq()
1064
1065        if hasattr(direct, "__delitem__"):
1066            e = creator()
1067            addall(e)
1068
1069            del direct[e.a]
1070            del control[e.a]
1071            assert_eq()
1072
1073            e = creator()
1074            try:
1075                del direct[e.a]
1076            except KeyError:
1077                self.assert_(e not in canary.removed)
1078
1079        if hasattr(direct, "clear"):
1080            addall(creator(), creator(), creator())
1081
1082            direct.clear()
1083            control.clear()
1084            assert_eq()
1085
1086            direct.clear()
1087            control.clear()
1088            assert_eq()
1089
1090        if hasattr(direct, "pop"):
1091            e = creator()
1092            addall(e)
1093
1094            direct.pop(e.a)
1095            control.pop(e.a)
1096            assert_eq()
1097
1098            e = creator()
1099            try:
1100                direct.pop(e.a)
1101            except KeyError:
1102                self.assert_(e not in canary.removed)
1103
1104        if hasattr(direct, "popitem"):
1105            zap()
1106            e = creator()
1107            addall(e)
1108
1109            direct.popitem()
1110            control.popitem()
1111            assert_eq()
1112
1113        if hasattr(direct, "setdefault"):
1114            e = creator()
1115
1116            val_a = direct.setdefault(e.a, e)
1117            val_b = control.setdefault(e.a, e)
1118            assert_eq()
1119            self.assert_(val_a is val_b)
1120
1121            val_a = direct.setdefault(e.a, e)
1122            val_b = control.setdefault(e.a, e)
1123            assert_eq()
1124            self.assert_(val_a is val_b)
1125
1126        if hasattr(direct, "update"):
1127            e = creator()
1128            d = dict([(ee.a, ee) for ee in [e, creator(), creator()]])
1129            addall(e, creator())
1130
1131            direct.update(d)
1132            control.update(d)
1133            assert_eq()
1134
1135            kw = dict([(ee.a, ee) for ee in [e, creator()]])
1136            direct.update(**kw)
1137            control.update(**kw)
1138            assert_eq()
1139
1140    def _test_dict_bulk(self, typecallable, creator=None):
1141        if creator is None:
1142            creator = self.dictable_entity
1143
1144        class Foo(object):
1145            pass
1146
1147        canary = Canary()
1148        instrumentation.register_class(Foo)
1149        d = attributes.register_attribute(
1150            Foo,
1151            "attr",
1152            uselist=True,
1153            typecallable=typecallable,
1154            useobject=True,
1155        )
1156        canary.listen(d)
1157
1158        obj = Foo()
1159        direct = obj.attr
1160
1161        e1 = creator()
1162        collections.collection_adapter(direct).append_with_event(e1)
1163
1164        like_me = typecallable()
1165        e2 = creator()
1166        like_me.set(e2)
1167
1168        self.assert_(obj.attr is direct)
1169        obj.attr = like_me
1170        self.assert_(obj.attr is not direct)
1171        self.assert_(obj.attr is not like_me)
1172        self.assert_(
1173            set(collections.collection_adapter(obj.attr)) == set([e2])
1174        )
1175        self.assert_(e1 in canary.removed)
1176        self.assert_(e2 in canary.added)
1177
1178        # key validity on bulk assignment is a basic feature of
1179        # MappedCollection but is not present in basic, @converter-less
1180        # dict collections.
1181        e3 = creator()
1182        real_dict = dict(keyignored1=e3)
1183        obj.attr = real_dict
1184        self.assert_(obj.attr is not real_dict)
1185        self.assert_("keyignored1" not in obj.attr)
1186        eq_(set(collections.collection_adapter(obj.attr)), set([e3]))
1187        self.assert_(e2 in canary.removed)
1188        self.assert_(e3 in canary.added)
1189
1190        obj.attr = typecallable()
1191        eq_(list(collections.collection_adapter(obj.attr)), [])
1192
1193        e4 = creator()
1194        try:
1195            obj.attr = [e4]
1196            self.assert_(False)
1197        except TypeError:
1198            self.assert_(e4 not in canary.data)
1199
1200    def test_dict(self):
1201        assert_raises_message(
1202            sa_exc.ArgumentError,
1203            "Type InstrumentedDict must elect an appender "
1204            "method to be a collection class",
1205            self._test_adapter,
1206            dict,
1207            self.dictable_entity,
1208            to_set=lambda c: set(c.values()),
1209        )
1210
1211        assert_raises_message(
1212            sa_exc.ArgumentError,
1213            "Type InstrumentedDict must elect an appender method "
1214            "to be a collection class",
1215            self._test_dict,
1216            dict,
1217        )
1218
1219    def test_dict_subclass(self):
1220        class MyDict(dict):
1221            @collection.appender
1222            @collection.internally_instrumented
1223            def set(self, item, _sa_initiator=None):
1224                self.__setitem__(item.a, item, _sa_initiator=_sa_initiator)
1225
1226            @collection.remover
1227            @collection.internally_instrumented
1228            def _remove(self, item, _sa_initiator=None):
1229                self.__delitem__(item.a, _sa_initiator=_sa_initiator)
1230
1231        self._test_adapter(
1232            MyDict, self.dictable_entity, to_set=lambda c: set(c.values())
1233        )
1234        self._test_dict(MyDict)
1235        self._test_dict_bulk(MyDict)
1236        self.assert_(getattr(MyDict, "_sa_instrumented") == id(MyDict))
1237
1238    def test_dict_subclass2(self):
1239        class MyEasyDict(collections.MappedCollection):
1240            def __init__(self):
1241                super(MyEasyDict, self).__init__(lambda e: e.a)
1242
1243        self._test_adapter(
1244            MyEasyDict, self.dictable_entity, to_set=lambda c: set(c.values())
1245        )
1246        self._test_dict(MyEasyDict)
1247        self._test_dict_bulk(MyEasyDict)
1248        self.assert_(getattr(MyEasyDict, "_sa_instrumented") == id(MyEasyDict))
1249
1250    def test_dict_subclass3(self):
1251        class MyOrdered(util.OrderedDict, collections.MappedCollection):
1252            def __init__(self):
1253                collections.MappedCollection.__init__(self, lambda e: e.a)
1254                util.OrderedDict.__init__(self)
1255
1256        self._test_adapter(
1257            MyOrdered, self.dictable_entity, to_set=lambda c: set(c.values())
1258        )
1259        self._test_dict(MyOrdered)
1260        self._test_dict_bulk(MyOrdered)
1261        self.assert_(getattr(MyOrdered, "_sa_instrumented") == id(MyOrdered))
1262
1263    def test_dict_duck(self):
1264        class DictLike(object):
1265            def __init__(self):
1266                self.data = dict()
1267
1268            @collection.appender
1269            @collection.replaces(1)
1270            def set(self, item):
1271                current = self.data.get(item.a, None)
1272                self.data[item.a] = item
1273                return current
1274
1275            @collection.remover
1276            def _remove(self, item):
1277                del self.data[item.a]
1278
1279            def __setitem__(self, key, value):
1280                self.data[key] = value
1281
1282            def __getitem__(self, key):
1283                return self.data[key]
1284
1285            def __delitem__(self, key):
1286                del self.data[key]
1287
1288            def values(self):
1289                return list(self.data.values())
1290
1291            def __contains__(self, key):
1292                return key in self.data
1293
1294            @collection.iterator
1295            def itervalues(self):
1296                return iter(self.data.values())
1297
1298            __hash__ = object.__hash__
1299
1300            def __eq__(self, other):
1301                return self.data == other
1302
1303            def __repr__(self):
1304                return "DictLike(%s)" % repr(self.data)
1305
1306        self._test_adapter(
1307            DictLike, self.dictable_entity, to_set=lambda c: set(c.values())
1308        )
1309        self._test_dict(DictLike)
1310        self._test_dict_bulk(DictLike)
1311        self.assert_(getattr(DictLike, "_sa_instrumented") == id(DictLike))
1312
1313    def test_dict_emulates(self):
1314        class DictIsh(object):
1315            __emulates__ = dict
1316
1317            def __init__(self):
1318                self.data = dict()
1319
1320            @collection.appender
1321            @collection.replaces(1)
1322            def set(self, item):
1323                current = self.data.get(item.a, None)
1324                self.data[item.a] = item
1325                return current
1326
1327            @collection.remover
1328            def _remove(self, item):
1329                del self.data[item.a]
1330
1331            def __setitem__(self, key, value):
1332                self.data[key] = value
1333
1334            def __getitem__(self, key):
1335                return self.data[key]
1336
1337            def __delitem__(self, key):
1338                del self.data[key]
1339
1340            def values(self):
1341                return list(self.data.values())
1342
1343            def __contains__(self, key):
1344                return key in self.data
1345
1346            @collection.iterator
1347            def itervalues(self):
1348                return iter(self.data.values())
1349
1350            __hash__ = object.__hash__
1351
1352            def __eq__(self, other):
1353                return self.data == other
1354
1355            def __repr__(self):
1356                return "DictIsh(%s)" % repr(self.data)
1357
1358        self._test_adapter(
1359            DictIsh, self.dictable_entity, to_set=lambda c: set(c.values())
1360        )
1361        self._test_dict(DictIsh)
1362        self._test_dict_bulk(DictIsh)
1363        self.assert_(getattr(DictIsh, "_sa_instrumented") == id(DictIsh))
1364
1365    def _test_object(self, typecallable, creator=None):
1366        if creator is None:
1367            creator = self.entity_maker
1368
1369        class Foo(object):
1370            pass
1371
1372        canary = Canary()
1373        instrumentation.register_class(Foo)
1374        d = attributes.register_attribute(
1375            Foo,
1376            "attr",
1377            uselist=True,
1378            typecallable=typecallable,
1379            useobject=True,
1380        )
1381        canary.listen(d)
1382
1383        obj = Foo()
1384        adapter = collections.collection_adapter(obj.attr)
1385        direct = obj.attr
1386        control = set()
1387
1388        def assert_eq():
1389            self.assert_(set(direct) == canary.data)
1390            self.assert_(set(adapter) == canary.data)
1391            self.assert_(direct == control)
1392
1393        # There is no API for object collections.  We'll make one up
1394        # for the purposes of the test.
1395        e = creator()
1396        direct.push(e)
1397        control.add(e)
1398        assert_eq()
1399
1400        direct.zark(e)
1401        control.remove(e)
1402        assert_eq()
1403
1404        e = creator()
1405        direct.maybe_zark(e)
1406        control.discard(e)
1407        assert_eq()
1408
1409        e = creator()
1410        direct.push(e)
1411        control.add(e)
1412        assert_eq()
1413
1414        e = creator()
1415        direct.maybe_zark(e)
1416        control.discard(e)
1417        assert_eq()
1418
1419    def test_object_duck(self):
1420        class MyCollection(object):
1421            def __init__(self):
1422                self.data = set()
1423
1424            @collection.appender
1425            def push(self, item):
1426                self.data.add(item)
1427
1428            @collection.remover
1429            def zark(self, item):
1430                self.data.remove(item)
1431
1432            @collection.removes_return()
1433            def maybe_zark(self, item):
1434                if item in self.data:
1435                    self.data.remove(item)
1436                    return item
1437
1438            @collection.iterator
1439            def __iter__(self):
1440                return iter(self.data)
1441
1442            __hash__ = object.__hash__
1443
1444            def __eq__(self, other):
1445                return self.data == other
1446
1447        self._test_adapter(MyCollection)
1448        self._test_object(MyCollection)
1449        self.assert_(
1450            getattr(MyCollection, "_sa_instrumented") == id(MyCollection)
1451        )
1452
1453    def test_object_emulates(self):
1454        class MyCollection2(object):
1455            __emulates__ = None
1456
1457            def __init__(self):
1458                self.data = set()
1459
1460            # looks like a list
1461
1462            def append(self, item):
1463                assert False
1464
1465            @collection.appender
1466            def push(self, item):
1467                self.data.add(item)
1468
1469            @collection.remover
1470            def zark(self, item):
1471                self.data.remove(item)
1472
1473            @collection.removes_return()
1474            def maybe_zark(self, item):
1475                if item in self.data:
1476                    self.data.remove(item)
1477                    return item
1478
1479            @collection.iterator
1480            def __iter__(self):
1481                return iter(self.data)
1482
1483            __hash__ = object.__hash__
1484
1485            def __eq__(self, other):
1486                return self.data == other
1487
1488        self._test_adapter(MyCollection2)
1489        self._test_object(MyCollection2)
1490        self.assert_(
1491            getattr(MyCollection2, "_sa_instrumented") == id(MyCollection2)
1492        )
1493
1494    def test_recipes(self):
1495        class Custom(object):
1496            def __init__(self):
1497                self.data = []
1498
1499            @collection.appender
1500            @collection.adds("entity")
1501            def put(self, entity):
1502                self.data.append(entity)
1503
1504            @collection.remover
1505            @collection.removes(1)
1506            def remove(self, entity):
1507                self.data.remove(entity)
1508
1509            @collection.adds(1)
1510            def push(self, *args):
1511                self.data.append(args[0])
1512
1513            @collection.removes("entity")
1514            def yank(self, entity, arg):
1515                self.data.remove(entity)
1516
1517            @collection.replaces(2)
1518            def replace(self, arg, entity, **kw):
1519                self.data.insert(0, entity)
1520                return self.data.pop()
1521
1522            @collection.removes_return()
1523            def pop(self, key):
1524                return self.data.pop()
1525
1526            @collection.iterator
1527            def __iter__(self):
1528                return iter(self.data)
1529
1530        class Foo(object):
1531            pass
1532
1533        canary = Canary()
1534        instrumentation.register_class(Foo)
1535        d = attributes.register_attribute(
1536            Foo, "attr", uselist=True, typecallable=Custom, useobject=True
1537        )
1538        canary.listen(d)
1539
1540        obj = Foo()
1541        adapter = collections.collection_adapter(obj.attr)
1542        direct = obj.attr
1543        control = list()
1544
1545        def assert_eq():
1546            self.assert_(set(direct) == canary.data)
1547            self.assert_(set(adapter) == canary.data)
1548            self.assert_(list(direct) == control)
1549
1550        creator = self.entity_maker
1551
1552        e1 = creator()
1553        direct.put(e1)
1554        control.append(e1)
1555        assert_eq()
1556
1557        e2 = creator()
1558        direct.put(entity=e2)
1559        control.append(e2)
1560        assert_eq()
1561
1562        direct.remove(e2)
1563        control.remove(e2)
1564        assert_eq()
1565
1566        direct.remove(entity=e1)
1567        control.remove(e1)
1568        assert_eq()
1569
1570        e3 = creator()
1571        direct.push(e3)
1572        control.append(e3)
1573        assert_eq()
1574
1575        direct.yank(e3, "blah")
1576        control.remove(e3)
1577        assert_eq()
1578
1579        e4, e5, e6, e7 = creator(), creator(), creator(), creator()
1580        direct.put(e4)
1581        direct.put(e5)
1582        control.append(e4)
1583        control.append(e5)
1584
1585        dr1 = direct.replace("foo", e6, bar="baz")
1586        control.insert(0, e6)
1587        cr1 = control.pop()
1588        assert_eq()
1589        self.assert_(dr1 is cr1)
1590
1591        dr2 = direct.replace(arg=1, entity=e7)
1592        control.insert(0, e7)
1593        cr2 = control.pop()
1594        assert_eq()
1595        self.assert_(dr2 is cr2)
1596
1597        dr3 = direct.pop("blah")
1598        cr3 = control.pop()
1599        assert_eq()
1600        self.assert_(dr3 is cr3)
1601
1602    def test_lifecycle(self):
1603        class Foo(object):
1604            pass
1605
1606        canary = Canary()
1607        creator = self.entity_maker
1608        instrumentation.register_class(Foo)
1609        d = attributes.register_attribute(
1610            Foo, "attr", uselist=True, useobject=True
1611        )
1612        canary.listen(d)
1613
1614        obj = Foo()
1615        col1 = obj.attr
1616
1617        e1 = creator()
1618        obj.attr.append(e1)
1619
1620        e2 = creator()
1621        bulk1 = [e2]
1622        # empty & sever col1 from obj
1623        obj.attr = bulk1
1624
1625        # as of [ticket:3913] the old collection
1626        # remains unchanged
1627        self.assert_(len(col1) == 1)
1628
1629        self.assert_(len(canary.data) == 1)
1630        self.assert_(obj.attr is not col1)
1631        self.assert_(obj.attr is not bulk1)
1632        self.assert_(obj.attr == bulk1)
1633
1634        e3 = creator()
1635        col1.append(e3)
1636        self.assert_(e3 not in canary.data)
1637        self.assert_(collections.collection_adapter(col1) is None)
1638
1639        obj.attr[0] = e3
1640        self.assert_(e3 in canary.data)
1641
1642
class DictHelpersTest(fixtures.MappedTest):
    """Map keyed (dict-based) collection classes on a real relationship.

    ``Parent.children`` is mapped with each collection class under test
    (``cascade="all, delete-orphan"``).  The collection is then flushed,
    expunged and reloaded around additions, same-key replacements and
    removals.
    """

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "parents",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("label", String(128)),
        )
        child_columns = [
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column(
                "parent_id", Integer, ForeignKey("parents.id"), nullable=False
            ),
        ]
        # three free-form string columns: a, b, c
        child_columns.extend(Column(name, String(128)) for name in "abc")
        Table("children", metadata, *child_columns)

    @classmethod
    def setup_classes(cls):
        class Parent(cls.Basic):
            def __init__(self, label=None):
                self.label = label

        class Child(cls.Basic):
            def __init__(self, a=None, b=None, c=None):
                self.a, self.b, self.c = a, b, c

    def _map_parent_children(self, collection_class):
        """Map Child, then Parent with a delete-orphan ``children``
        relationship using ``collection_class``; return (Parent, Child)."""
        Parent, Child = self.classes.Parent, self.classes.Child
        mapper(Child, self.tables.children)
        mapper(
            Parent,
            self.tables.parents,
            properties={
                "children": relationship(
                    Child,
                    collection_class=collection_class,
                    cascade="all, delete-orphan",
                )
            },
        )
        return Parent, Child

    def _reload_parent(self, session, pid):
        """Expunge everything and load Parent ``pid`` afresh."""
        session.expunge_all()
        return session.query(self.classes.Parent).get(pid)

    @staticmethod
    def _member_count(parent):
        """Count the members visible through the children adapter."""
        return len(list(collections.collection_adapter(parent.children)))

    def _test_scalar_mapped(self, collection_class):
        """Exercise a collection keyed on the single attribute ``a``."""
        Parent, Child = self._map_parent_children(collection_class)

        p = Parent()
        p.children["foo"] = Child("foo", "value")
        p.children["bar"] = Child("bar", "value")
        session = create_session()
        session.add(p)
        session.flush()
        pid = p.id

        p = self._reload_parent(session, pid)
        eq_(set(p.children.keys()), set(["foo", "bar"]))
        cid = p.children["foo"].id

        # appending under an existing key replaces the previous member
        collections.collection_adapter(p.children).append_with_event(
            Child("foo", "newvalue")
        )
        session.flush()

        p = self._reload_parent(session, pid)
        eq_(set(p.children.keys()), set(["foo", "bar"]))
        ne_(p.children["foo"].id, cid)
        eq_(self._member_count(p), 2)
        session.flush()

        p = self._reload_parent(session, pid)
        eq_(self._member_count(p), 2)

        collections.collection_adapter(p.children).remove_with_event(
            p.children["foo"]
        )
        eq_(self._member_count(p), 1)
        session.flush()

        p = self._reload_parent(session, pid)
        eq_(self._member_count(p), 1)

        del p.children["bar"]
        eq_(self._member_count(p), 0)
        session.flush()

        p = self._reload_parent(session, pid)
        eq_(self._member_count(p), 0)

    def _test_composite_mapped(self, collection_class):
        """Exercise a collection keyed on the tuple ``(a, b)``."""
        Parent, Child = self._map_parent_children(collection_class)
        both_keys = set([("foo", "1"), ("foo", "2")])

        p = Parent()
        p.children[("foo", "1")] = Child("foo", "1", "value 1")
        p.children[("foo", "2")] = Child("foo", "2", "value 2")

        session = create_session()
        session.add(p)
        session.flush()
        pid = p.id

        p = self._reload_parent(session, pid)
        eq_(set(p.children.keys()), both_keys)
        cid = p.children[("foo", "1")].id

        # appending under an existing key replaces the previous member
        collections.collection_adapter(p.children).append_with_event(
            Child("foo", "1", "newvalue")
        )
        session.flush()

        p = self._reload_parent(session, pid)
        eq_(set(p.children.keys()), both_keys)
        ne_(p.children[("foo", "1")].id, cid)
        eq_(self._member_count(p), 2)

    def test_mapped_collection(self):
        self._test_scalar_mapped(
            collections.mapped_collection(lambda c: c.a)
        )

    def test_mapped_collection2(self):
        self._test_composite_mapped(
            collections.mapped_collection(lambda c: (c.a, c.b))
        )

    def test_attr_mapped_collection(self):
        self._test_scalar_mapped(collections.attribute_mapped_collection("a"))

    def test_declarative_column_mapped(self):
        """test that uncompiled attribute usage works with
        column_mapped_collection"""

        from sqlalchemy.ext.declarative import declarative_base

        BaseObject = declarative_base()

        class Foo(BaseObject):
            __tablename__ = "foo"
            id = Column(Integer(), primary_key=True)
            bar_id = Column(Integer, ForeignKey("bar.id"))

        cases = [
            (Foo.id, Foo(id=3), 3),
            ((Foo.id, Foo.bar_id), Foo(id=3, bar_id=12), (3, 12)),
        ]
        for spec, obj, expected in cases:
            keyfunc = collections.column_mapped_collection(spec)().keyfunc
            eq_(keyfunc(obj), expected)

    def test_column_mapped_assertions(self):
        # neither a plain string nor a text() construct is a column
        for bad_spec in ("a", text("a")):
            assert_raises_message(
                sa_exc.ArgumentError,
                "Column-based expression object expected "
                "for argument 'mapping_spec'; got: 'a'",
                collections.column_mapped_collection,
                bad_spec,
            )

    def test_column_mapped_collection(self):
        a_col = self.tables.children.c.a
        self._test_scalar_mapped(collections.column_mapped_collection(a_col))

    def test_column_mapped_collection2(self):
        cols = self.tables.children.c
        self._test_composite_mapped(
            collections.column_mapped_collection((cols.a, cols.b))
        )

    def test_mixin(self):
        class Ordered(util.OrderedDict, collections.MappedCollection):
            def __init__(self):
                collections.MappedCollection.__init__(self, lambda v: v.a)
                util.OrderedDict.__init__(self)

        self._test_scalar_mapped(Ordered)

    def test_mixin2(self):
        class Ordered2(util.OrderedDict, collections.MappedCollection):
            def __init__(self, keyfunc):
                collections.MappedCollection.__init__(self, keyfunc)
                util.OrderedDict.__init__(self)

        # a factory function (not a class) used as the collection_class
        def factory():
            return Ordered2(lambda v: (v.a, v.b))

        self._test_composite_mapped(factory)
1904
1905
class ColumnMappedWSerialize(fixtures.MappedTest):
    """Check column_mapped_collection key functions against multi-table
    and indirect-table edge cases, including pickle round trips."""

    run_create_tables = run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "foo",
            metadata,
            Column("id", Integer(), primary_key=True),
            Column("b", String(128)),
        )
        Table(
            "bar",
            metadata,
            Column("id", Integer(), primary_key=True),
            Column("foo_id", Integer, ForeignKey("foo.id")),
            Column("bat_id", Integer),
            schema="x",
        )

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

        class Bar(Foo):
            pass

    def test_indirect_table_column_mapped(self):
        Foo, Bar = self.classes.Foo, self.classes.Bar
        foo = self.tables.foo
        bar = self.tables["x.bar"]
        mapper(Foo, foo, properties={"foo_id": foo.c.id})
        mapper(Bar, bar, inherits=Foo, properties={"bar_id": bar.c.id})

        sample = Bar(foo_id=1, bar_id=2, bat_id=3)
        self._run_test(
            [
                (Foo.foo_id, sample, 1),
                ((Bar.bar_id, Bar.bat_id), sample, (2, 3)),
                (Bar.foo_id, sample, 1),
                (bar.c.id, sample, 2),
            ]
        )

    def test_selectable_column_mapped(self):
        from sqlalchemy import select

        aliased_foo = select([self.tables.foo]).alias()
        Foo = self.classes.Foo
        mapper(Foo, aliased_foo)
        self._run_test(
            [
                (Foo.b, Foo(b=5), 5),
                (aliased_foo.c.b, Foo(b=5), 5),
            ]
        )

    def _run_test(self, specs):
        """For each (spec, obj, expected), check the collection's keyfunc
        directly and after one and two pickle round trips per pickler."""
        from sqlalchemy.testing.util import picklers

        for spec, obj, expected in specs:
            coll = collections.column_mapped_collection(spec)()
            eq_(coll.keyfunc(obj), expected)
            # __reduce__ must survive repeated serialization
            for loads, dumps in picklers():
                restored = coll
                for _ in range(2):
                    restored = loads(dumps(restored))
                    eq_(restored.keyfunc(obj), expected)
1977
1978
1979class CustomCollectionsTest(fixtures.MappedTest):
1980    """test the integration of collections with mapped classes."""
1981
1982    @classmethod
1983    def define_tables(cls, metadata):
1984        Table(
1985            "sometable",
1986            metadata,
1987            Column(
1988                "col1",
1989                Integer,
1990                primary_key=True,
1991                test_needs_autoincrement=True,
1992            ),
1993            Column("data", String(30)),
1994        )
1995        Table(
1996            "someothertable",
1997            metadata,
1998            Column(
1999                "col1",
2000                Integer,
2001                primary_key=True,
2002                test_needs_autoincrement=True,
2003            ),
2004            Column("scol1", Integer, ForeignKey("sometable.col1")),
2005            Column("data", String(20)),
2006        )
2007
2008    def test_basic(self):
2009        someothertable, sometable = (
2010            self.tables.someothertable,
2011            self.tables.sometable,
2012        )
2013
2014        class MyList(list):
2015            pass
2016
2017        class Foo(object):
2018            pass
2019
2020        class Bar(object):
2021            pass
2022
2023        mapper(
2024            Foo,
2025            sometable,
2026            properties={"bars": relationship(Bar, collection_class=MyList)},
2027        )
2028        mapper(Bar, someothertable)
2029        f = Foo()
2030        assert isinstance(f.bars, MyList)
2031
2032    def test_lazyload(self):
2033        """test that a 'set' can be used as a collection and can lazyload."""
2034
2035        someothertable, sometable = (
2036            self.tables.someothertable,
2037            self.tables.sometable,
2038        )
2039
2040        class Foo(object):
2041            pass
2042
2043        class Bar(object):
2044            pass
2045
2046        mapper(
2047            Foo,
2048            sometable,
2049            properties={"bars": relationship(Bar, collection_class=set)},
2050        )
2051        mapper(Bar, someothertable)
2052        f = Foo()
2053        f.bars.add(Bar())
2054        f.bars.add(Bar())
2055        sess = create_session()
2056        sess.add(f)
2057        sess.flush()
2058        sess.expunge_all()
2059        f = sess.query(Foo).get(f.col1)
2060        assert len(list(f.bars)) == 2
2061        f.bars.clear()
2062
2063    def test_dict(self):
2064        """test that a 'dict' can be used as a collection and can lazyload."""
2065
2066        someothertable, sometable = (
2067            self.tables.someothertable,
2068            self.tables.sometable,
2069        )
2070
2071        class Foo(object):
2072            pass
2073
2074        class Bar(object):
2075            pass
2076
2077        class AppenderDict(dict):
2078            @collection.appender
2079            def set(self, item):
2080                self[id(item)] = item
2081
2082            @collection.remover
2083            def remove(self, item):
2084                if id(item) in self:
2085                    del self[id(item)]
2086
2087        mapper(
2088            Foo,
2089            sometable,
2090            properties={
2091                "bars": relationship(Bar, collection_class=AppenderDict)
2092            },
2093        )
2094        mapper(Bar, someothertable)
2095        f = Foo()
2096        f.bars.set(Bar())
2097        f.bars.set(Bar())
2098        sess = create_session()
2099        sess.add(f)
2100        sess.flush()
2101        sess.expunge_all()
2102        f = sess.query(Foo).get(f.col1)
2103        assert len(list(f.bars)) == 2
2104        f.bars.clear()
2105
2106    def test_dict_wrapper(self):
2107        """test that the supplied 'dict' wrapper can be used as a
2108        collection and can lazyload."""
2109
2110        someothertable, sometable = (
2111            self.tables.someothertable,
2112            self.tables.sometable,
2113        )
2114
2115        class Foo(object):
2116            pass
2117
2118        class Bar(object):
2119            def __init__(self, data):
2120                self.data = data
2121
2122        mapper(
2123            Foo,
2124            sometable,
2125            properties={
2126                "bars": relationship(
2127                    Bar,
2128                    collection_class=collections.column_mapped_collection(
2129                        someothertable.c.data
2130                    ),
2131                )
2132            },
2133        )
2134        mapper(Bar, someothertable)
2135
2136        f = Foo()
2137        col = collections.collection_adapter(f.bars)
2138        col.append_with_event(Bar("a"))
2139        col.append_with_event(Bar("b"))
2140        sess = create_session()
2141        sess.add(f)
2142        sess.flush()
2143        sess.expunge_all()
2144        f = sess.query(Foo).get(f.col1)
2145        assert len(list(f.bars)) == 2
2146
2147        strongref = list(f.bars.values())
2148        existing = set([id(b) for b in strongref])
2149
2150        col = collections.collection_adapter(f.bars)
2151        col.append_with_event(Bar("b"))
2152        f.bars["a"] = Bar("a")
2153        sess.flush()
2154        sess.expunge_all()
2155        f = sess.query(Foo).get(f.col1)
2156        assert len(list(f.bars)) == 2
2157
2158        replaced = set([id(b) for b in list(f.bars.values())])
2159        ne_(existing, replaced)
2160
2161    def test_list(self):
2162        self._test_list(list)
2163
2164    def test_list_no_setslice(self):
2165        class ListLike(object):
2166            def __init__(self):
2167                self.data = list()
2168
2169            def append(self, item):
2170                self.data.append(item)
2171
2172            def remove(self, item):
2173                self.data.remove(item)
2174
2175            def insert(self, index, item):
2176                self.data.insert(index, item)
2177
2178            def pop(self, index=-1):
2179                return self.data.pop(index)
2180
2181            def extend(self):
2182                assert False
2183
2184            def __len__(self):
2185                return len(self.data)
2186
2187            def __setitem__(self, key, value):
2188                self.data[key] = value
2189
2190            def __getitem__(self, key):
2191                return self.data[key]
2192
2193            def __delitem__(self, key):
2194                del self.data[key]
2195
2196            def __iter__(self):
2197                return iter(self.data)
2198
2199            __hash__ = object.__hash__
2200
2201            def __eq__(self, other):
2202                return self.data == other
2203
2204            def __repr__(self):
2205                return "ListLike(%s)" % repr(self.data)
2206
2207        self._test_list(ListLike)
2208
2209    def _test_list(self, listcls):
2210        someothertable, sometable = (
2211            self.tables.someothertable,
2212            self.tables.sometable,
2213        )
2214
2215        class Parent(object):
2216            pass
2217
2218        class Child(object):
2219            pass
2220
2221        mapper(
2222            Parent,
2223            sometable,
2224            properties={
2225                "children": relationship(Child, collection_class=listcls)
2226            },
2227        )
2228        mapper(Child, someothertable)
2229
2230        control = list()
2231        p = Parent()
2232
2233        o = Child()
2234        control.append(o)
2235        p.children.append(o)
2236        assert control == p.children
2237        assert control == list(p.children)
2238
2239        o = [Child(), Child(), Child(), Child()]
2240        control.extend(o)
2241        p.children.extend(o)
2242        assert control == p.children
2243        assert control == list(p.children)
2244
2245        assert control[0] == p.children[0]
2246        assert control[-1] == p.children[-1]
2247        assert control[1:3] == p.children[1:3]
2248
2249        del control[1]
2250        del p.children[1]
2251        assert control == p.children
2252        assert control == list(p.children)
2253
2254        o = [Child()]
2255        control[1:3] = o
2256
2257        p.children[1:3] = o
2258        assert control == p.children
2259        assert control == list(p.children)
2260
2261        o = [Child(), Child(), Child(), Child()]
2262        control[1:3] = o
2263        p.children[1:3] = o
2264        assert control == p.children
2265        assert control == list(p.children)
2266
2267        o = [Child(), Child(), Child(), Child()]
2268        control[-1:-2] = o
2269        p.children[-1:-2] = o
2270        assert control == p.children
2271        assert control == list(p.children)
2272
2273        o = [Child(), Child(), Child(), Child()]
2274        control[4:] = o
2275        p.children[4:] = o
2276        assert control == p.children
2277        assert control == list(p.children)
2278
2279        o = Child()
2280        control.insert(0, o)
2281        p.children.insert(0, o)
2282        assert control == p.children
2283        assert control == list(p.children)
2284
2285        o = Child()
2286        control.insert(3, o)
2287        p.children.insert(3, o)
2288        assert control == p.children
2289        assert control == list(p.children)
2290
2291        o = Child()
2292        control.insert(999, o)
2293        p.children.insert(999, o)
2294        assert control == p.children
2295        assert control == list(p.children)
2296
2297        del control[0:1]
2298        del p.children[0:1]
2299        assert control == p.children
2300        assert control == list(p.children)
2301
2302        del control[1:1]
2303        del p.children[1:1]
2304        assert control == p.children
2305        assert control == list(p.children)
2306
2307        del control[1:3]
2308        del p.children[1:3]
2309        assert control == p.children
2310        assert control == list(p.children)
2311
2312        del control[7:]
2313        del p.children[7:]
2314        assert control == p.children
2315        assert control == list(p.children)
2316
2317        assert control.pop() == p.children.pop()
2318        assert control == p.children
2319        assert control == list(p.children)
2320
2321        assert control.pop(0) == p.children.pop(0)
2322        assert control == p.children
2323        assert control == list(p.children)
2324
2325        assert control.pop(2) == p.children.pop(2)
2326        assert control == p.children
2327        assert control == list(p.children)
2328
2329        o = Child()
2330        control.insert(2, o)
2331        p.children.insert(2, o)
2332        assert control == p.children
2333        assert control == list(p.children)
2334
2335        control.remove(o)
2336        p.children.remove(o)
2337        assert control == p.children
2338        assert control == list(p.children)
2339
2340    def test_custom(self):
2341        someothertable, sometable = (
2342            self.tables.someothertable,
2343            self.tables.sometable,
2344        )
2345
2346        class Parent(object):
2347            pass
2348
2349        class Child(object):
2350            pass
2351
2352        class MyCollection(object):
2353            def __init__(self):
2354                self.data = []
2355
2356            @collection.appender
2357            def append(self, value):
2358                self.data.append(value)
2359
2360            @collection.remover
2361            def remove(self, value):
2362                self.data.remove(value)
2363
2364            @collection.iterator
2365            def __iter__(self):
2366                return iter(self.data)
2367
2368        mapper(
2369            Parent,
2370            sometable,
2371            properties={
2372                "children": relationship(Child, collection_class=MyCollection)
2373            },
2374        )
2375        mapper(Child, someothertable)
2376
2377        control = list()
2378        p1 = Parent()
2379
2380        o = Child()
2381        control.append(o)
2382        p1.children.append(o)
2383        assert control == list(p1.children)
2384
2385        o = Child()
2386        control.append(o)
2387        p1.children.append(o)
2388        assert control == list(p1.children)
2389
2390        o = Child()
2391        control.append(o)
2392        p1.children.append(o)
2393        assert control == list(p1.children)
2394
2395        sess = create_session()
2396        sess.add(p1)
2397        sess.flush()
2398        sess.expunge_all()
2399
2400        p2 = sess.query(Parent).get(p1.col1)
2401        o = list(p2.children)
2402        assert len(o) == 3
2403
2404
class InstrumentationTest(fixtures.ORMTest):
    """Low-level checks on collection class instrumentation and on the
    collection adapter's owner-reference flag."""

    def test_uncooperative_descriptor_in_sweep(self):
        """_instrument_class copes with a class attribute whose descriptor
        raises AttributeError when read."""

        class DoNotTouch(object):
            def __get__(self, obj, owner):
                raise AttributeError

        class Touchy(list):
            no_touch = DoNotTouch()

        # the name is in the class namespace and in dir(), yet reading it
        # through normal attribute access fails
        assert "no_touch" in Touchy.__dict__
        assert "no_touch" in dir(Touchy)
        assert not hasattr(Touchy, "no_touch")

        collections._instrument_class(Touchy)

    def test_referenced_by_owner(self):
        """An adapter reports _referenced_by_owner until a new list is
        assigned to the owning attribute."""

        class Foo(object):
            pass

        instrumentation.register_class(Foo)
        attributes.register_attribute(
            Foo, "attr", uselist=True, useobject=True
        )

        owner = Foo()
        owner.attr.append(3)
        adapter = collections.collection_adapter(owner.attr)
        assert adapter._referenced_by_owner

        # after assignment, the previous adapter no longer belongs to the
        # owner
        owner.attr = []
        assert not adapter._referenced_by_owner
2437