1from operator import and_
2
3import sqlalchemy as sa
4from sqlalchemy import exc as sa_exc
5from sqlalchemy import ForeignKey
6from sqlalchemy import Integer
7from sqlalchemy import String
8from sqlalchemy import text
9from sqlalchemy import util
10from sqlalchemy.orm import attributes
11from sqlalchemy.orm import create_session
12from sqlalchemy.orm import instrumentation
13from sqlalchemy.orm import mapper
14from sqlalchemy.orm import relationship
15import sqlalchemy.orm.collections as collections
16from sqlalchemy.orm.collections import collection
17from sqlalchemy.testing import assert_raises
18from sqlalchemy.testing import assert_raises_message
19from sqlalchemy.testing import eq_
20from sqlalchemy.testing import fixtures
21from sqlalchemy.testing import ne_
22from sqlalchemy.testing import uses_deprecated
23from sqlalchemy.testing.schema import Column
24from sqlalchemy.testing.schema import Table
25
26
class Canary(sa.orm.interfaces.AttributeExtension):
    """Attribute extension that records collection events for the tests.

    ``data`` tracks the membership implied by the events received so far;
    ``added`` and ``removed`` accumulate every value ever reported, and are
    used to assert that no value is reported as appended (or removed) twice.
    """

    def __init__(self):
        self.data, self.added, self.removed = set(), set(), set()

    def append(self, obj, value, initiator):
        # a given value may only be reported as appended once
        assert value not in self.added
        self.data.add(value)
        self.added.add(value)
        return value

    def remove(self, obj, value, initiator):
        # a given value may only be reported as removed once
        assert value not in self.removed
        self.data.remove(value)
        self.removed.add(value)

    def set(self, obj, value, oldvalue, initiator):
        # string values are swapped for a freshly made entity
        replacement = (
            CollectionsTest.entity_maker() if isinstance(value, str) else value
        )
        if oldvalue is not None:
            self.remove(obj, oldvalue, None)
        self.append(obj, replacement, None)
        return replacement
52
53
54class CollectionsTest(fixtures.ORMTest):
55    class Entity(object):
56        def __init__(self, a=None, b=None, c=None):
57            self.a = a
58            self.b = b
59            self.c = c
60
61        def __repr__(self):
62            return str((id(self), self.a, self.b, self.c))
63
64    @classmethod
65    def setup_class(cls):
66        instrumentation.register_class(cls.Entity)
67
68    @classmethod
69    def teardown_class(cls):
70        instrumentation.unregister_class(cls.Entity)
71        super(CollectionsTest, cls).teardown_class()
72
73    _entity_id = 1
74
75    @classmethod
76    def entity_maker(cls):
77        cls._entity_id += 1
78        return cls.Entity(cls._entity_id)
79
80    @classmethod
81    def dictable_entity(cls, a=None, b=None, c=None):
82        id_ = cls._entity_id = cls._entity_id + 1
83        return cls.Entity(a or str(id_), b or "value %s" % id, c)
84
85    def _test_adapter(self, typecallable, creator=None, to_set=None):
86        if creator is None:
87            creator = self.entity_maker
88
89        class Foo(object):
90            pass
91
92        canary = Canary()
93        instrumentation.register_class(Foo)
94        attributes.register_attribute(
95            Foo,
96            "attr",
97            uselist=True,
98            extension=canary,
99            typecallable=typecallable,
100            useobject=True,
101        )
102
103        obj = Foo()
104        adapter = collections.collection_adapter(obj.attr)
105        direct = obj.attr
106        if to_set is None:
107
108            def to_set(col):
109                return set(col)
110
111        def assert_eq():
112            self.assert_(to_set(direct) == canary.data)
113            self.assert_(set(adapter) == canary.data)
114
115        def assert_ne():
116            self.assert_(to_set(direct) != canary.data)
117
118        e1, e2 = creator(), creator()
119
120        adapter.append_with_event(e1)
121        assert_eq()
122
123        adapter.append_without_event(e2)
124        assert_ne()
125        canary.data.add(e2)
126        assert_eq()
127
128        adapter.remove_without_event(e2)
129        assert_ne()
130        canary.data.remove(e2)
131        assert_eq()
132
133        adapter.remove_with_event(e1)
134        assert_eq()
135
136    def _test_list(self, typecallable, creator=None):
137        if creator is None:
138            creator = self.entity_maker
139
140        class Foo(object):
141            pass
142
143        canary = Canary()
144        instrumentation.register_class(Foo)
145        attributes.register_attribute(
146            Foo,
147            "attr",
148            uselist=True,
149            extension=canary,
150            typecallable=typecallable,
151            useobject=True,
152        )
153
154        obj = Foo()
155        adapter = collections.collection_adapter(obj.attr)
156        direct = obj.attr
157        control = list()
158
159        def assert_eq():
160            eq_(set(direct), canary.data)
161            eq_(set(adapter), canary.data)
162            eq_(direct, control)
163
164        # assume append() is available for list tests
165        e = creator()
166        direct.append(e)
167        control.append(e)
168        assert_eq()
169
170        if hasattr(direct, "pop"):
171            direct.pop()
172            control.pop()
173            assert_eq()
174
175        if hasattr(direct, "__setitem__"):
176            e = creator()
177            direct.append(e)
178            control.append(e)
179
180            e = creator()
181            direct[0] = e
182            control[0] = e
183            assert_eq()
184
185            if util.reduce(
186                and_,
187                [
188                    hasattr(direct, a)
189                    for a in ("__delitem__", "insert", "__len__")
190                ],
191                True,
192            ):
193                values = [creator(), creator(), creator(), creator()]
194                direct[slice(0, 1)] = values
195                control[slice(0, 1)] = values
196                assert_eq()
197
198                values = [creator(), creator()]
199                direct[slice(0, -1, 2)] = values
200                control[slice(0, -1, 2)] = values
201                assert_eq()
202
203                values = [creator()]
204                direct[slice(0, -1)] = values
205                control[slice(0, -1)] = values
206                assert_eq()
207
208                values = [creator(), creator(), creator()]
209                control[:] = values
210                direct[:] = values
211
212                def invalid():
213                    direct[slice(0, 6, 2)] = [creator()]
214
215                assert_raises(ValueError, invalid)
216
217        if hasattr(direct, "__delitem__"):
218            e = creator()
219            direct.append(e)
220            control.append(e)
221            del direct[-1]
222            del control[-1]
223            assert_eq()
224
225            if hasattr(direct, "__getslice__"):
226                for e in [creator(), creator(), creator(), creator()]:
227                    direct.append(e)
228                    control.append(e)
229
230                del direct[:-3]
231                del control[:-3]
232                assert_eq()
233
234                del direct[0:1]
235                del control[0:1]
236                assert_eq()
237
238                del direct[::2]
239                del control[::2]
240                assert_eq()
241
242        if hasattr(direct, "remove"):
243            e = creator()
244            direct.append(e)
245            control.append(e)
246
247            direct.remove(e)
248            control.remove(e)
249            assert_eq()
250
251        if hasattr(direct, "__setitem__") or hasattr(direct, "__setslice__"):
252
253            values = [creator(), creator()]
254            direct[:] = values
255            control[:] = values
256            assert_eq()
257
258            # test slice assignment where
259            # slice size goes over the number of items
260            values = [creator(), creator()]
261            direct[1:3] = values
262            control[1:3] = values
263            assert_eq()
264
265            values = [creator(), creator()]
266            direct[0:1] = values
267            control[0:1] = values
268            assert_eq()
269
270            values = [creator()]
271            direct[0:] = values
272            control[0:] = values
273            assert_eq()
274
275            values = [creator()]
276            direct[:1] = values
277            control[:1] = values
278            assert_eq()
279
280            values = [creator()]
281            direct[-1::2] = values
282            control[-1::2] = values
283            assert_eq()
284
285            values = [creator()] * len(direct[1::2])
286            direct[1::2] = values
287            control[1::2] = values
288            assert_eq()
289
290            values = [creator(), creator()]
291            direct[-1:-3] = values
292            control[-1:-3] = values
293            assert_eq()
294
295            values = [creator(), creator()]
296            direct[-2:-1] = values
297            control[-2:-1] = values
298            assert_eq()
299
300            values = [creator()]
301            direct[0:0] = values
302            control[0:0] = values
303            assert_eq()
304
305        if hasattr(direct, "__delitem__") or hasattr(direct, "__delslice__"):
306            for i in range(1, 4):
307                e = creator()
308                direct.append(e)
309                control.append(e)
310
311            del direct[-1:]
312            del control[-1:]
313            assert_eq()
314
315            del direct[1:2]
316            del control[1:2]
317            assert_eq()
318
319            del direct[:]
320            del control[:]
321            assert_eq()
322
323        if hasattr(direct, "clear"):
324            for i in range(1, 4):
325                e = creator()
326                direct.append(e)
327                control.append(e)
328
329            direct.clear()
330            control.clear()
331            assert_eq()
332
333        if hasattr(direct, "extend"):
334            values = [creator(), creator(), creator()]
335
336            direct.extend(values)
337            control.extend(values)
338            assert_eq()
339
340        if hasattr(direct, "__iadd__"):
341            values = [creator(), creator(), creator()]
342
343            direct += values
344            control += values
345            assert_eq()
346
347            direct += []
348            control += []
349            assert_eq()
350
351            values = [creator(), creator()]
352            obj.attr += values
353            control += values
354            assert_eq()
355
356        if hasattr(direct, "__imul__"):
357            direct *= 2
358            control *= 2
359            assert_eq()
360
361            obj.attr *= 2
362            control *= 2
363            assert_eq()
364
365    def _test_list_bulk(self, typecallable, creator=None):
366        if creator is None:
367            creator = self.entity_maker
368
369        class Foo(object):
370            pass
371
372        canary = Canary()
373        instrumentation.register_class(Foo)
374        attributes.register_attribute(
375            Foo,
376            "attr",
377            uselist=True,
378            extension=canary,
379            typecallable=typecallable,
380            useobject=True,
381        )
382
383        obj = Foo()
384        direct = obj.attr
385
386        e1 = creator()
387        obj.attr.append(e1)
388
389        like_me = typecallable()
390        e2 = creator()
391        like_me.append(e2)
392
393        self.assert_(obj.attr is direct)
394        obj.attr = like_me
395        self.assert_(obj.attr is not direct)
396        self.assert_(obj.attr is not like_me)
397        self.assert_(set(obj.attr) == set([e2]))
398        self.assert_(e1 in canary.removed)
399        self.assert_(e2 in canary.added)
400
401        e3 = creator()
402        real_list = [e3]
403        obj.attr = real_list
404        self.assert_(obj.attr is not real_list)
405        self.assert_(set(obj.attr) == set([e3]))
406        self.assert_(e2 in canary.removed)
407        self.assert_(e3 in canary.added)
408
409        e4 = creator()
410        try:
411            obj.attr = set([e4])
412            self.assert_(False)
413        except TypeError:
414            self.assert_(e4 not in canary.data)
415            self.assert_(e3 in canary.data)
416
417        e5 = creator()
418        e6 = creator()
419        e7 = creator()
420        obj.attr = [e5, e6, e7]
421        self.assert_(e5 in canary.added)
422        self.assert_(e6 in canary.added)
423        self.assert_(e7 in canary.added)
424
425        obj.attr = [e6, e7]
426        self.assert_(e5 in canary.removed)
427        self.assert_(e6 in canary.added)
428        self.assert_(e7 in canary.added)
429        self.assert_(e6 not in canary.removed)
430        self.assert_(e7 not in canary.removed)
431
432    def test_list(self):
433        self._test_adapter(list)
434        self._test_list(list)
435        self._test_list_bulk(list)
436
437    def test_list_setitem_with_slices(self):
438
439        # this is a "list" that has no __setslice__
440        # or __delslice__ methods.  The __setitem__
441        # and __delitem__ must therefore accept
442        # slice objects (i.e. as in py3k)
443        class ListLike(object):
444            def __init__(self):
445                self.data = list()
446
447            def append(self, item):
448                self.data.append(item)
449
450            def remove(self, item):
451                self.data.remove(item)
452
453            def insert(self, index, item):
454                self.data.insert(index, item)
455
456            def pop(self, index=-1):
457                return self.data.pop(index)
458
459            def extend(self):
460                assert False
461
462            def __len__(self):
463                return len(self.data)
464
465            def __setitem__(self, key, value):
466                self.data[key] = value
467
468            def __getitem__(self, key):
469                return self.data[key]
470
471            def __delitem__(self, key):
472                del self.data[key]
473
474            def __iter__(self):
475                return iter(self.data)
476
477            __hash__ = object.__hash__
478
479            def __eq__(self, other):
480                return self.data == other
481
482            def __repr__(self):
483                return "ListLike(%s)" % repr(self.data)
484
485        self._test_adapter(ListLike)
486        self._test_list(ListLike)
487        self._test_list_bulk(ListLike)
488
489    def test_list_subclass(self):
490        class MyList(list):
491            pass
492
493        self._test_adapter(MyList)
494        self._test_list(MyList)
495        self._test_list_bulk(MyList)
496        self.assert_(getattr(MyList, "_sa_instrumented") == id(MyList))
497
498    def test_list_duck(self):
499        class ListLike(object):
500            def __init__(self):
501                self.data = list()
502
503            def append(self, item):
504                self.data.append(item)
505
506            def remove(self, item):
507                self.data.remove(item)
508
509            def insert(self, index, item):
510                self.data.insert(index, item)
511
512            def pop(self, index=-1):
513                return self.data.pop(index)
514
515            def extend(self):
516                assert False
517
518            def __iter__(self):
519                return iter(self.data)
520
521            __hash__ = object.__hash__
522
523            def __eq__(self, other):
524                return self.data == other
525
526            def __repr__(self):
527                return "ListLike(%s)" % repr(self.data)
528
529        self._test_adapter(ListLike)
530        self._test_list(ListLike)
531        self._test_list_bulk(ListLike)
532        self.assert_(getattr(ListLike, "_sa_instrumented") == id(ListLike))
533
534    def test_list_emulates(self):
535        class ListIsh(object):
536            __emulates__ = list
537
538            def __init__(self):
539                self.data = list()
540
541            def append(self, item):
542                self.data.append(item)
543
544            def remove(self, item):
545                self.data.remove(item)
546
547            def insert(self, index, item):
548                self.data.insert(index, item)
549
550            def pop(self, index=-1):
551                return self.data.pop(index)
552
553            def extend(self):
554                assert False
555
556            def __iter__(self):
557                return iter(self.data)
558
559            __hash__ = object.__hash__
560
561            def __eq__(self, other):
562                return self.data == other
563
564            def __repr__(self):
565                return "ListIsh(%s)" % repr(self.data)
566
567        self._test_adapter(ListIsh)
568        self._test_list(ListIsh)
569        self._test_list_bulk(ListIsh)
570        self.assert_(getattr(ListIsh, "_sa_instrumented") == id(ListIsh))
571
    def _test_set(self, typecallable, creator=None):
        """Exercise set-style mutation of an instrumented collection.

        Every mutation is applied both to the instrumented collection
        (``direct``) and to a plain ``set`` (``control``).  After each step
        the collection must equal the control, and both the collection and
        its adapter must contain exactly what the :class:`.Canary`
        extension recorded through events.  ``add()``, ``clear()`` and
        ``pop()`` are used unconditionally; every other operation is only
        exercised when ``typecallable`` provides it.  In-place operators
        must reject list operands with ``TypeError``.

        :param typecallable: collection class under test.
        :param creator: factory for member values; defaults to
            :meth:`entity_maker`.
        """
        if creator is None:
            creator = self.entity_maker

        class Foo(object):
            pass

        canary = Canary()
        instrumentation.register_class(Foo)
        attributes.register_attribute(
            Foo,
            "attr",
            uselist=True,
            extension=canary,
            typecallable=typecallable,
            useobject=True,
        )

        obj = Foo()
        adapter = collections.collection_adapter(obj.attr)
        direct = obj.attr
        control = set()

        # the collection, its adapter and the canary must all agree
        def assert_eq():
            eq_(set(direct), canary.data)
            eq_(set(adapter), canary.data)
            eq_(direct, control)

        # add each value to both the collection and the control, then check
        def addall(*values):
            for item in values:
                direct.add(item)
                control.add(item)
            assert_eq()

        # empty both the collection (one remove() per member) and the control
        def zap():
            for item in list(direct):
                direct.remove(item)
            control.clear()

        addall(creator())

        # re-adding an existing member must not report a second append
        # (Canary.append asserts against duplicates)
        e = creator()
        addall(e)
        addall(e)

        if hasattr(direct, "remove"):
            e = creator()
            addall(e)

            direct.remove(e)
            control.remove(e)
            assert_eq()

            # removing a non-member must raise KeyError and report nothing
            e = creator()
            try:
                direct.remove(e)
            except KeyError:
                assert_eq()
                self.assert_(e not in canary.removed)
            else:
                self.assert_(False)

        if hasattr(direct, "discard"):
            e = creator()
            addall(e)

            direct.discard(e)
            control.discard(e)
            assert_eq()

            # discarding a non-member is silent and reports nothing
            e = creator()
            direct.discard(e)
            self.assert_(e not in canary.removed)
            assert_eq()

        if hasattr(direct, "update"):
            zap()
            e = creator()
            addall(e)

            values = set([e, creator(), creator()])

            direct.update(values)
            control.update(values)
            assert_eq()

        if hasattr(direct, "__ior__"):
            zap()
            e = creator()
            addall(e)

            values = set([e, creator(), creator()])

            direct |= values
            control |= values
            assert_eq()

            # cover self-assignment short-circuit
            values = set([e, creator(), creator()])
            obj.attr |= values
            control |= values
            assert_eq()

            values = frozenset([e, creator()])
            obj.attr |= values
            control |= values
            assert_eq()

            # a list operand must be rejected
            try:
                direct |= [e, creator()]
                assert False
            except TypeError:
                assert True

        addall(creator(), creator())
        direct.clear()
        control.clear()
        assert_eq()

        # note: the clear test previously needs
        # to have executed in order for this to
        # pass in all cases; else there's the possibility
        # of non-deterministic behavior.
        addall(creator())
        direct.pop()
        control.pop()
        assert_eq()

        if hasattr(direct, "difference_update"):
            zap()
            # note: ``e`` and every member of ``values`` are freshly created
            # and never added, so these differences remove nothing
            e = creator()
            addall(creator(), creator())
            values = set([creator()])

            direct.difference_update(values)
            control.difference_update(values)
            assert_eq()
            values.update(set([e, creator()]))
            direct.difference_update(values)
            control.difference_update(values)
            assert_eq()

        if hasattr(direct, "__isub__"):
            zap()
            e = creator()
            addall(creator(), creator())
            values = set([creator()])

            direct -= values
            control -= values
            assert_eq()
            values.update(set([e, creator()]))
            direct -= values
            control -= values
            assert_eq()

            values = set([creator()])
            obj.attr -= values
            control -= values
            assert_eq()

            values = frozenset([creator()])
            obj.attr -= values
            control -= values
            assert_eq()

            # a list operand must be rejected
            try:
                direct -= [e, creator()]
                assert False
            except TypeError:
                assert True

        if hasattr(direct, "intersection_update"):
            zap()
            e = creator()
            addall(e, creator(), creator())
            values = set(control)

            direct.intersection_update(values)
            control.intersection_update(values)
            assert_eq()

            values.update(set([e, creator()]))
            direct.intersection_update(values)
            control.intersection_update(values)
            assert_eq()

        if hasattr(direct, "__iand__"):
            zap()
            e = creator()
            addall(e, creator(), creator())
            values = set(control)

            direct &= values
            control &= values
            assert_eq()

            values.update(set([e, creator()]))
            direct &= values
            control &= values
            assert_eq()

            values.update(set([creator()]))
            obj.attr &= values
            control &= values
            assert_eq()

            # a list operand must be rejected
            try:
                direct &= [e, creator()]
                assert False
            except TypeError:
                assert True

        if hasattr(direct, "symmetric_difference_update"):
            zap()
            e = creator()
            addall(e, creator(), creator())

            values = set([e, creator()])
            direct.symmetric_difference_update(values)
            control.symmetric_difference_update(values)
            assert_eq()

            e = creator()
            addall(e)
            values = set([e])
            direct.symmetric_difference_update(values)
            control.symmetric_difference_update(values)
            assert_eq()

            values = set()
            direct.symmetric_difference_update(values)
            control.symmetric_difference_update(values)
            assert_eq()

        if hasattr(direct, "__ixor__"):
            zap()
            e = creator()
            addall(e, creator(), creator())

            values = set([e, creator()])
            direct ^= values
            control ^= values
            assert_eq()

            e = creator()
            addall(e)
            values = set([e])
            direct ^= values
            control ^= values
            assert_eq()

            values = set()
            direct ^= values
            control ^= values
            assert_eq()

            values = set([creator()])
            obj.attr ^= values
            control ^= values
            assert_eq()

            # a list operand must be rejected
            try:
                direct ^= [e, creator()]
                assert False
            except TypeError:
                assert True
839
840    def _test_set_bulk(self, typecallable, creator=None):
841        if creator is None:
842            creator = self.entity_maker
843
844        class Foo(object):
845            pass
846
847        canary = Canary()
848        instrumentation.register_class(Foo)
849        attributes.register_attribute(
850            Foo,
851            "attr",
852            uselist=True,
853            extension=canary,
854            typecallable=typecallable,
855            useobject=True,
856        )
857
858        obj = Foo()
859        direct = obj.attr
860
861        e1 = creator()
862        obj.attr.add(e1)
863
864        like_me = typecallable()
865        e2 = creator()
866        like_me.add(e2)
867
868        self.assert_(obj.attr is direct)
869        obj.attr = like_me
870        self.assert_(obj.attr is not direct)
871        self.assert_(obj.attr is not like_me)
872        self.assert_(obj.attr == set([e2]))
873        self.assert_(e1 in canary.removed)
874        self.assert_(e2 in canary.added)
875
876        e3 = creator()
877        real_set = set([e3])
878        obj.attr = real_set
879        self.assert_(obj.attr is not real_set)
880        self.assert_(obj.attr == set([e3]))
881        self.assert_(e2 in canary.removed)
882        self.assert_(e3 in canary.added)
883
884        e4 = creator()
885        try:
886            obj.attr = [e4]
887            self.assert_(False)
888        except TypeError:
889            self.assert_(e4 not in canary.data)
890            self.assert_(e3 in canary.data)
891
892    def test_set(self):
893        self._test_adapter(set)
894        self._test_set(set)
895        self._test_set_bulk(set)
896
897    def test_set_subclass(self):
898        class MySet(set):
899            pass
900
901        self._test_adapter(MySet)
902        self._test_set(MySet)
903        self._test_set_bulk(MySet)
904        self.assert_(getattr(MySet, "_sa_instrumented") == id(MySet))
905
906    def test_set_duck(self):
907        class SetLike(object):
908            def __init__(self):
909                self.data = set()
910
911            def add(self, item):
912                self.data.add(item)
913
914            def remove(self, item):
915                self.data.remove(item)
916
917            def discard(self, item):
918                self.data.discard(item)
919
920            def clear(self):
921                self.data.clear()
922
923            def pop(self):
924                return self.data.pop()
925
926            def update(self, other):
927                self.data.update(other)
928
929            def __iter__(self):
930                return iter(self.data)
931
932            __hash__ = object.__hash__
933
934            def __eq__(self, other):
935                return self.data == other
936
937        self._test_adapter(SetLike)
938        self._test_set(SetLike)
939        self._test_set_bulk(SetLike)
940        self.assert_(getattr(SetLike, "_sa_instrumented") == id(SetLike))
941
942    def test_set_emulates(self):
943        class SetIsh(object):
944            __emulates__ = set
945
946            def __init__(self):
947                self.data = set()
948
949            def add(self, item):
950                self.data.add(item)
951
952            def remove(self, item):
953                self.data.remove(item)
954
955            def discard(self, item):
956                self.data.discard(item)
957
958            def pop(self):
959                return self.data.pop()
960
961            def update(self, other):
962                self.data.update(other)
963
964            def __iter__(self):
965                return iter(self.data)
966
967            def clear(self):
968                self.data.clear()
969
970            __hash__ = object.__hash__
971
972            def __eq__(self, other):
973                return self.data == other
974
975        self._test_adapter(SetIsh)
976        self._test_set(SetIsh)
977        self._test_set_bulk(SetIsh)
978        self.assert_(getattr(SetIsh, "_sa_instrumented") == id(SetIsh))
979
980    def _test_dict(self, typecallable, creator=None):
981        if creator is None:
982            creator = self.dictable_entity
983
984        class Foo(object):
985            pass
986
987        canary = Canary()
988        instrumentation.register_class(Foo)
989        attributes.register_attribute(
990            Foo,
991            "attr",
992            uselist=True,
993            extension=canary,
994            typecallable=typecallable,
995            useobject=True,
996        )
997
998        obj = Foo()
999        adapter = collections.collection_adapter(obj.attr)
1000        direct = obj.attr
1001        control = dict()
1002
1003        def assert_eq():
1004            self.assert_(set(direct.values()) == canary.data)
1005            self.assert_(set(adapter) == canary.data)
1006            self.assert_(direct == control)
1007
1008        def addall(*values):
1009            for item in values:
1010                direct.set(item)
1011                control[item.a] = item
1012            assert_eq()
1013
1014        def zap():
1015            for item in list(adapter):
1016                direct.remove(item)
1017            control.clear()
1018
1019        # assume an 'set' method is available for tests
1020        addall(creator())
1021
1022        if hasattr(direct, "__setitem__"):
1023            e = creator()
1024            direct[e.a] = e
1025            control[e.a] = e
1026            assert_eq()
1027
1028            e = creator(e.a, e.b)
1029            direct[e.a] = e
1030            control[e.a] = e
1031            assert_eq()
1032
1033        if hasattr(direct, "__delitem__"):
1034            e = creator()
1035            addall(e)
1036
1037            del direct[e.a]
1038            del control[e.a]
1039            assert_eq()
1040
1041            e = creator()
1042            try:
1043                del direct[e.a]
1044            except KeyError:
1045                self.assert_(e not in canary.removed)
1046
1047        if hasattr(direct, "clear"):
1048            addall(creator(), creator(), creator())
1049
1050            direct.clear()
1051            control.clear()
1052            assert_eq()
1053
1054            direct.clear()
1055            control.clear()
1056            assert_eq()
1057
1058        if hasattr(direct, "pop"):
1059            e = creator()
1060            addall(e)
1061
1062            direct.pop(e.a)
1063            control.pop(e.a)
1064            assert_eq()
1065
1066            e = creator()
1067            try:
1068                direct.pop(e.a)
1069            except KeyError:
1070                self.assert_(e not in canary.removed)
1071
1072        if hasattr(direct, "popitem"):
1073            zap()
1074            e = creator()
1075            addall(e)
1076
1077            direct.popitem()
1078            control.popitem()
1079            assert_eq()
1080
1081        if hasattr(direct, "setdefault"):
1082            e = creator()
1083
1084            val_a = direct.setdefault(e.a, e)
1085            val_b = control.setdefault(e.a, e)
1086            assert_eq()
1087            self.assert_(val_a is val_b)
1088
1089            val_a = direct.setdefault(e.a, e)
1090            val_b = control.setdefault(e.a, e)
1091            assert_eq()
1092            self.assert_(val_a is val_b)
1093
1094        if hasattr(direct, "update"):
1095            e = creator()
1096            d = dict([(ee.a, ee) for ee in [e, creator(), creator()]])
1097            addall(e, creator())
1098
1099            direct.update(d)
1100            control.update(d)
1101            assert_eq()
1102
1103            kw = dict([(ee.a, ee) for ee in [e, creator()]])
1104            direct.update(**kw)
1105            control.update(**kw)
1106            assert_eq()
1107
1108    def _test_dict_bulk(self, typecallable, creator=None):
1109        if creator is None:
1110            creator = self.dictable_entity
1111
1112        class Foo(object):
1113            pass
1114
1115        canary = Canary()
1116        instrumentation.register_class(Foo)
1117        attributes.register_attribute(
1118            Foo,
1119            "attr",
1120            uselist=True,
1121            extension=canary,
1122            typecallable=typecallable,
1123            useobject=True,
1124        )
1125
1126        obj = Foo()
1127        direct = obj.attr
1128
1129        e1 = creator()
1130        collections.collection_adapter(direct).append_with_event(e1)
1131
1132        like_me = typecallable()
1133        e2 = creator()
1134        like_me.set(e2)
1135
1136        self.assert_(obj.attr is direct)
1137        obj.attr = like_me
1138        self.assert_(obj.attr is not direct)
1139        self.assert_(obj.attr is not like_me)
1140        self.assert_(
1141            set(collections.collection_adapter(obj.attr)) == set([e2])
1142        )
1143        self.assert_(e1 in canary.removed)
1144        self.assert_(e2 in canary.added)
1145
1146        # key validity on bulk assignment is a basic feature of
1147        # MappedCollection but is not present in basic, @converter-less
1148        # dict collections.
1149        e3 = creator()
1150        if isinstance(obj.attr, collections.MappedCollection):
1151            real_dict = dict(badkey=e3)
1152            try:
1153                obj.attr = real_dict
1154                self.assert_(False)
1155            except TypeError:
1156                pass
1157            self.assert_(obj.attr is not real_dict)
1158            self.assert_("badkey" not in obj.attr)
1159            eq_(set(collections.collection_adapter(obj.attr)), set([e2]))
1160            self.assert_(e3 not in canary.added)
1161        else:
1162            real_dict = dict(keyignored1=e3)
1163            obj.attr = real_dict
1164            self.assert_(obj.attr is not real_dict)
1165            self.assert_("keyignored1" not in obj.attr)
1166            eq_(set(collections.collection_adapter(obj.attr)), set([e3]))
1167            self.assert_(e2 in canary.removed)
1168            self.assert_(e3 in canary.added)
1169
1170        obj.attr = typecallable()
1171        eq_(list(collections.collection_adapter(obj.attr)), [])
1172
1173        e4 = creator()
1174        try:
1175            obj.attr = [e4]
1176            self.assert_(False)
1177        except TypeError:
1178            self.assert_(e4 not in canary.data)
1179
1180    def test_dict(self):
1181        assert_raises_message(
1182            sa_exc.ArgumentError,
1183            "Type InstrumentedDict must elect an appender "
1184            "method to be a collection class",
1185            self._test_adapter,
1186            dict,
1187            self.dictable_entity,
1188            to_set=lambda c: set(c.values()),
1189        )
1190
1191        assert_raises_message(
1192            sa_exc.ArgumentError,
1193            "Type InstrumentedDict must elect an appender method "
1194            "to be a collection class",
1195            self._test_dict,
1196            dict,
1197        )
1198
1199    def test_dict_subclass(self):
1200        class MyDict(dict):
1201            @collection.appender
1202            @collection.internally_instrumented
1203            def set(self, item, _sa_initiator=None):
1204                self.__setitem__(item.a, item, _sa_initiator=_sa_initiator)
1205
1206            @collection.remover
1207            @collection.internally_instrumented
1208            def _remove(self, item, _sa_initiator=None):
1209                self.__delitem__(item.a, _sa_initiator=_sa_initiator)
1210
1211        self._test_adapter(
1212            MyDict, self.dictable_entity, to_set=lambda c: set(c.values())
1213        )
1214        self._test_dict(MyDict)
1215        self._test_dict_bulk(MyDict)
1216        self.assert_(getattr(MyDict, "_sa_instrumented") == id(MyDict))
1217
1218    def test_dict_subclass2(self):
1219        class MyEasyDict(collections.MappedCollection):
1220            def __init__(self):
1221                super(MyEasyDict, self).__init__(lambda e: e.a)
1222
1223        self._test_adapter(
1224            MyEasyDict, self.dictable_entity, to_set=lambda c: set(c.values())
1225        )
1226        self._test_dict(MyEasyDict)
1227        self._test_dict_bulk(MyEasyDict)
1228        self.assert_(getattr(MyEasyDict, "_sa_instrumented") == id(MyEasyDict))
1229
1230    def test_dict_subclass3(self):
1231        class MyOrdered(util.OrderedDict, collections.MappedCollection):
1232            def __init__(self):
1233                collections.MappedCollection.__init__(self, lambda e: e.a)
1234                util.OrderedDict.__init__(self)
1235
1236        self._test_adapter(
1237            MyOrdered, self.dictable_entity, to_set=lambda c: set(c.values())
1238        )
1239        self._test_dict(MyOrdered)
1240        self._test_dict_bulk(MyOrdered)
1241        self.assert_(getattr(MyOrdered, "_sa_instrumented") == id(MyOrdered))
1242
1243    def test_dict_subclass4(self):
1244        # tests #2654
1245        class MyDict(collections.MappedCollection):
1246            def __init__(self):
1247                super(MyDict, self).__init__(lambda value: "k%d" % value)
1248
1249            @collection.converter
1250            def _convert(self, dictlike):
1251                for key, value in dictlike.items():
1252                    yield value + 5
1253
1254        class Foo(object):
1255            pass
1256
1257        canary = Canary()
1258
1259        instrumentation.register_class(Foo)
1260        attributes.register_attribute(
1261            Foo,
1262            "attr",
1263            uselist=True,
1264            extension=canary,
1265            typecallable=MyDict,
1266            useobject=True,
1267        )
1268
1269        f = Foo()
1270        f.attr = {"k1": 1, "k2": 2}
1271
1272        eq_(f.attr, {"k7": 7, "k6": 6})
1273
1274    def test_dict_duck(self):
1275        class DictLike(object):
1276            def __init__(self):
1277                self.data = dict()
1278
1279            @collection.appender
1280            @collection.replaces(1)
1281            def set(self, item):
1282                current = self.data.get(item.a, None)
1283                self.data[item.a] = item
1284                return current
1285
1286            @collection.remover
1287            def _remove(self, item):
1288                del self.data[item.a]
1289
1290            def __setitem__(self, key, value):
1291                self.data[key] = value
1292
1293            def __getitem__(self, key):
1294                return self.data[key]
1295
1296            def __delitem__(self, key):
1297                del self.data[key]
1298
1299            def values(self):
1300                return list(self.data.values())
1301
1302            def __contains__(self, key):
1303                return key in self.data
1304
1305            @collection.iterator
1306            def itervalues(self):
1307                return iter(self.data.values())
1308
1309            __hash__ = object.__hash__
1310
1311            def __eq__(self, other):
1312                return self.data == other
1313
1314            def __repr__(self):
1315                return "DictLike(%s)" % repr(self.data)
1316
1317        self._test_adapter(
1318            DictLike, self.dictable_entity, to_set=lambda c: set(c.values())
1319        )
1320        self._test_dict(DictLike)
1321        self._test_dict_bulk(DictLike)
1322        self.assert_(getattr(DictLike, "_sa_instrumented") == id(DictLike))
1323
1324    def test_dict_emulates(self):
1325        class DictIsh(object):
1326            __emulates__ = dict
1327
1328            def __init__(self):
1329                self.data = dict()
1330
1331            @collection.appender
1332            @collection.replaces(1)
1333            def set(self, item):
1334                current = self.data.get(item.a, None)
1335                self.data[item.a] = item
1336                return current
1337
1338            @collection.remover
1339            def _remove(self, item):
1340                del self.data[item.a]
1341
1342            def __setitem__(self, key, value):
1343                self.data[key] = value
1344
1345            def __getitem__(self, key):
1346                return self.data[key]
1347
1348            def __delitem__(self, key):
1349                del self.data[key]
1350
1351            def values(self):
1352                return list(self.data.values())
1353
1354            def __contains__(self, key):
1355                return key in self.data
1356
1357            @collection.iterator
1358            def itervalues(self):
1359                return iter(self.data.values())
1360
1361            __hash__ = object.__hash__
1362
1363            def __eq__(self, other):
1364                return self.data == other
1365
1366            def __repr__(self):
1367                return "DictIsh(%s)" % repr(self.data)
1368
1369        self._test_adapter(
1370            DictIsh, self.dictable_entity, to_set=lambda c: set(c.values())
1371        )
1372        self._test_dict(DictIsh)
1373        self._test_dict_bulk(DictIsh)
1374        self.assert_(getattr(DictIsh, "_sa_instrumented") == id(DictIsh))
1375
1376    def _test_object(self, typecallable, creator=None):
1377        if creator is None:
1378            creator = self.entity_maker
1379
1380        class Foo(object):
1381            pass
1382
1383        canary = Canary()
1384        instrumentation.register_class(Foo)
1385        attributes.register_attribute(
1386            Foo,
1387            "attr",
1388            uselist=True,
1389            extension=canary,
1390            typecallable=typecallable,
1391            useobject=True,
1392        )
1393
1394        obj = Foo()
1395        adapter = collections.collection_adapter(obj.attr)
1396        direct = obj.attr
1397        control = set()
1398
1399        def assert_eq():
1400            self.assert_(set(direct) == canary.data)
1401            self.assert_(set(adapter) == canary.data)
1402            self.assert_(direct == control)
1403
1404        # There is no API for object collections.  We'll make one up
1405        # for the purposes of the test.
1406        e = creator()
1407        direct.push(e)
1408        control.add(e)
1409        assert_eq()
1410
1411        direct.zark(e)
1412        control.remove(e)
1413        assert_eq()
1414
1415        e = creator()
1416        direct.maybe_zark(e)
1417        control.discard(e)
1418        assert_eq()
1419
1420        e = creator()
1421        direct.push(e)
1422        control.add(e)
1423        assert_eq()
1424
1425        e = creator()
1426        direct.maybe_zark(e)
1427        control.discard(e)
1428        assert_eq()
1429
1430    def test_object_duck(self):
1431        class MyCollection(object):
1432            def __init__(self):
1433                self.data = set()
1434
1435            @collection.appender
1436            def push(self, item):
1437                self.data.add(item)
1438
1439            @collection.remover
1440            def zark(self, item):
1441                self.data.remove(item)
1442
1443            @collection.removes_return()
1444            def maybe_zark(self, item):
1445                if item in self.data:
1446                    self.data.remove(item)
1447                    return item
1448
1449            @collection.iterator
1450            def __iter__(self):
1451                return iter(self.data)
1452
1453            __hash__ = object.__hash__
1454
1455            def __eq__(self, other):
1456                return self.data == other
1457
1458        self._test_adapter(MyCollection)
1459        self._test_object(MyCollection)
1460        self.assert_(
1461            getattr(MyCollection, "_sa_instrumented") == id(MyCollection)
1462        )
1463
1464    def test_object_emulates(self):
1465        class MyCollection2(object):
1466            __emulates__ = None
1467
1468            def __init__(self):
1469                self.data = set()
1470
1471            # looks like a list
1472
1473            def append(self, item):
1474                assert False
1475
1476            @collection.appender
1477            def push(self, item):
1478                self.data.add(item)
1479
1480            @collection.remover
1481            def zark(self, item):
1482                self.data.remove(item)
1483
1484            @collection.removes_return()
1485            def maybe_zark(self, item):
1486                if item in self.data:
1487                    self.data.remove(item)
1488                    return item
1489
1490            @collection.iterator
1491            def __iter__(self):
1492                return iter(self.data)
1493
1494            __hash__ = object.__hash__
1495
1496            def __eq__(self, other):
1497                return self.data == other
1498
1499        self._test_adapter(MyCollection2)
1500        self._test_object(MyCollection2)
1501        self.assert_(
1502            getattr(MyCollection2, "_sa_instrumented") == id(MyCollection2)
1503        )
1504
    def test_recipes(self):
        """Exercise the collection decorator recipes on a duck-typed,
        list-backed collection: ``adds``/``removes`` (by argument name or
        by position), ``replaces`` and ``removes_return``.

        Every mutation is mirrored on a plain ``list`` control. After each
        step ``assert_eq`` cross-checks the collection, its adapter and
        the ``Canary`` membership set.
        """

        class Custom(object):
            def __init__(self):
                self.data = []

            # appender; the added member is the argument named "entity"
            @collection.appender
            @collection.adds("entity")
            def put(self, entity):
                self.data.append(entity)

            # remover; the removed member is positional argument 1
            @collection.remover
            @collection.removes(1)
            def remove(self, entity):
                self.data.remove(entity)

            # not the appender, but still reports positional argument 1 as
            # added
            @collection.adds(1)
            def push(self, *args):
                self.data.append(args[0])

            # removal identified by argument name; ``arg`` is ignored
            @collection.removes("entity")
            def yank(self, entity, arg):
                self.data.remove(entity)

            # positional argument 2 (``entity``) is added; the returned
            # value is the member being replaced
            @collection.replaces(2)
            def replace(self, arg, entity, **kw):
                self.data.insert(0, entity)
                return self.data.pop()

            # the returned member is reported as removed; ``key`` is unused
            @collection.removes_return()
            def pop(self, key):
                return self.data.pop()

            @collection.iterator
            def __iter__(self):
                return iter(self.data)

        class Foo(object):
            pass

        canary = Canary()
        instrumentation.register_class(Foo)
        attributes.register_attribute(
            Foo,
            "attr",
            uselist=True,
            extension=canary,
            typecallable=Custom,
            useobject=True,
        )

        obj = Foo()
        adapter = collections.collection_adapter(obj.attr)
        direct = obj.attr
        control = list()

        # the collection, its adapter and the canary must all agree; the
        # list comparison also checks ordering against the control
        def assert_eq():
            self.assert_(set(direct) == canary.data)
            self.assert_(set(adapter) == canary.data)
            self.assert_(list(direct) == control)

        creator = self.entity_maker

        # adds("entity"): positional and keyword call forms
        e1 = creator()
        direct.put(e1)
        control.append(e1)
        assert_eq()

        e2 = creator()
        direct.put(entity=e2)
        control.append(e2)
        assert_eq()

        # removes(1): positional and keyword call forms
        direct.remove(e2)
        control.remove(e2)
        assert_eq()

        direct.remove(entity=e1)
        control.remove(e1)
        assert_eq()

        # adds(1) on a *args method
        e3 = creator()
        direct.push(e3)
        control.append(e3)
        assert_eq()

        # removes("entity") with an extra positional argument
        direct.yank(e3, "blah")
        control.remove(e3)
        assert_eq()

        e4, e5, e6, e7 = creator(), creator(), creator(), creator()
        direct.put(e4)
        direct.put(e5)
        control.append(e4)
        control.append(e5)

        # replaces(2): the returned (popped) member must match the control
        dr1 = direct.replace("foo", e6, bar="baz")
        control.insert(0, e6)
        cr1 = control.pop()
        assert_eq()
        self.assert_(dr1 is cr1)

        dr2 = direct.replace(arg=1, entity=e7)
        control.insert(0, e7)
        cr2 = control.pop()
        assert_eq()
        self.assert_(dr2 is cr2)

        # removes_return(): the returned member is the one removed
        dr3 = direct.pop("blah")
        cr3 = control.pop()
        assert_eq()
        self.assert_(dr3 is cr3)
1616
1617    def test_lifecycle(self):
1618        class Foo(object):
1619            pass
1620
1621        canary = Canary()
1622        creator = self.entity_maker
1623        instrumentation.register_class(Foo)
1624        attributes.register_attribute(
1625            Foo, "attr", uselist=True, extension=canary, useobject=True
1626        )
1627
1628        obj = Foo()
1629        col1 = obj.attr
1630
1631        e1 = creator()
1632        obj.attr.append(e1)
1633
1634        e2 = creator()
1635        bulk1 = [e2]
1636        # empty & sever col1 from obj
1637        obj.attr = bulk1
1638
1639        # as of [ticket:3913] the old collection
1640        # remains unchanged
1641        self.assert_(len(col1) == 1)
1642
1643        self.assert_(len(canary.data) == 1)
1644        self.assert_(obj.attr is not col1)
1645        self.assert_(obj.attr is not bulk1)
1646        self.assert_(obj.attr == bulk1)
1647
1648        e3 = creator()
1649        col1.append(e3)
1650        self.assert_(e3 not in canary.data)
1651        self.assert_(collections.collection_adapter(col1) is None)
1652
1653        obj.attr[0] = e3
1654        self.assert_(e3 in canary.data)
1655
1656
class DictHelpersTest(fixtures.MappedTest):
    """Round-trip the dict-based collection helpers (mapped_collection,
    attribute_mapped_collection, column_mapped_collection and
    MappedCollection mixins) through a real mapping, session and flush."""

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "parents",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("label", String(128)),
        )
        Table(
            "children",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column(
                "parent_id", Integer, ForeignKey("parents.id"), nullable=False
            ),
            Column("a", String(128)),
            Column("b", String(128)),
            Column("c", String(128)),
        )

    @classmethod
    def setup_classes(cls):
        class Parent(cls.Basic):
            def __init__(self, label=None):
                self.label = label

        class Child(cls.Basic):
            def __init__(self, a=None, b=None, c=None):
                self.a = a
                self.b = b
                self.c = c

    @staticmethod
    def _adapter_len(coll):
        """Return the number of members the collection's adapter yields."""
        return len(list(collections.collection_adapter(coll)))

    @staticmethod
    def _flush_and_reload(session, mapped_cls, pk):
        """Flush pending changes, clear the identity map, and return a
        freshly queried ``mapped_cls`` instance for primary key ``pk``."""
        session.flush()
        session.expunge_all()
        return session.query(mapped_cls).get(pk)

    def _test_scalar_mapped(self, collection_class):
        """Persist a Parent whose ``children`` dict is keyed by a scalar
        derived from ``Child.a``. Verify that keyed replacement, adapter
        removal and ``del`` each survive a flush and reload.

        :param collection_class: dict-based collection class or factory
         whose keys come from ``Child.a``.
        """
        parents, children, Parent, Child = (
            self.tables.parents,
            self.tables.children,
            self.classes.Parent,
            self.classes.Child,
        )

        mapper(Child, children)
        mapper(
            Parent,
            parents,
            properties={
                "children": relationship(
                    Child,
                    collection_class=collection_class,
                    cascade="all, delete-orphan",
                )
            },
        )

        p = Parent()
        p.children["foo"] = Child("foo", "value")
        p.children["bar"] = Child("bar", "value")
        session = create_session()
        session.add(p)
        session.flush()
        pid = p.id
        session.expunge_all()

        p = session.query(Parent).get(pid)

        eq_(set(p.children.keys()), set(["foo", "bar"]))
        cid = p.children["foo"].id

        # appending a member whose key collides with an existing one
        # replaces it; after reload the "foo" member has a new identity
        collections.collection_adapter(p.children).append_with_event(
            Child("foo", "newvalue")
        )

        p = self._flush_and_reload(session, Parent, pid)

        eq_(set(p.children.keys()), set(["foo", "bar"]))
        ne_(p.children["foo"].id, cid)

        eq_(self._adapter_len(p.children), 2)
        p = self._flush_and_reload(session, Parent, pid)
        eq_(self._adapter_len(p.children), 2)

        collections.collection_adapter(p.children).remove_with_event(
            p.children["foo"]
        )

        eq_(self._adapter_len(p.children), 1)
        p = self._flush_and_reload(session, Parent, pid)
        eq_(self._adapter_len(p.children), 1)

        del p.children["bar"]
        eq_(self._adapter_len(p.children), 0)
        p = self._flush_and_reload(session, Parent, pid)
        eq_(self._adapter_len(p.children), 0)

    def _test_composite_mapped(self, collection_class):
        """Like ``_test_scalar_mapped`` but with tuple keys built from
        ``(Child.a, Child.b)``; verifies keyed replacement across a flush
        and reload.

        :param collection_class: dict-based collection class or factory
         whose keys are ``(a, b)`` tuples.
        """
        parents, children, Parent, Child = (
            self.tables.parents,
            self.tables.children,
            self.classes.Parent,
            self.classes.Child,
        )

        mapper(Child, children)
        mapper(
            Parent,
            parents,
            properties={
                "children": relationship(
                    Child,
                    collection_class=collection_class,
                    cascade="all, delete-orphan",
                )
            },
        )

        p = Parent()
        p.children[("foo", "1")] = Child("foo", "1", "value 1")
        p.children[("foo", "2")] = Child("foo", "2", "value 2")

        session = create_session()
        session.add(p)
        session.flush()
        pid = p.id
        session.expunge_all()

        p = session.query(Parent).get(pid)

        eq_(set(p.children.keys()), set([("foo", "1"), ("foo", "2")]))
        cid = p.children[("foo", "1")].id

        # colliding composite key replaces the existing member
        collections.collection_adapter(p.children).append_with_event(
            Child("foo", "1", "newvalue")
        )

        p = self._flush_and_reload(session, Parent, pid)

        eq_(set(p.children.keys()), set([("foo", "1"), ("foo", "2")]))
        ne_(p.children[("foo", "1")].id, cid)

        eq_(self._adapter_len(p.children), 2)

    def test_mapped_collection(self):
        collection_class = collections.mapped_collection(lambda c: c.a)
        self._test_scalar_mapped(collection_class)

    def test_mapped_collection2(self):
        collection_class = collections.mapped_collection(lambda c: (c.a, c.b))
        self._test_composite_mapped(collection_class)

    def test_attr_mapped_collection(self):
        collection_class = collections.attribute_mapped_collection("a")
        self._test_scalar_mapped(collection_class)

    def test_declarative_column_mapped(self):
        """test that uncompiled attribute usage works with
        column_mapped_collection"""

        from sqlalchemy.ext.declarative import declarative_base

        BaseObject = declarative_base()

        class Foo(BaseObject):
            __tablename__ = "foo"
            id = Column(Integer(), primary_key=True)
            bar_id = Column(Integer, ForeignKey("bar.id"))

        for spec, obj, expected in (
            (Foo.id, Foo(id=3), 3),
            ((Foo.id, Foo.bar_id), Foo(id=3, bar_id=12), (3, 12)),
        ):
            eq_(
                collections.column_mapped_collection(spec)().keyfunc(obj),
                expected,
            )

    def test_column_mapped_assertions(self):
        assert_raises_message(
            sa_exc.ArgumentError,
            "Column-based expression object expected "
            "for argument 'mapping_spec'; got: 'a'",
            collections.column_mapped_collection,
            "a",
        )
        assert_raises_message(
            sa_exc.ArgumentError,
            "Column-based expression object expected "
            "for argument 'mapping_spec'; got: 'a'",
            collections.column_mapped_collection,
            text("a"),
        )

    def test_column_mapped_collection(self):
        children = self.tables.children

        collection_class = collections.column_mapped_collection(children.c.a)
        self._test_scalar_mapped(collection_class)

    def test_column_mapped_collection2(self):
        children = self.tables.children

        collection_class = collections.column_mapped_collection(
            (children.c.a, children.c.b)
        )
        self._test_composite_mapped(collection_class)

    def test_mixin(self):
        class Ordered(util.OrderedDict, collections.MappedCollection):
            def __init__(self):
                collections.MappedCollection.__init__(self, lambda v: v.a)
                util.OrderedDict.__init__(self)

        self._test_scalar_mapped(Ordered)

    def test_mixin2(self):
        class Ordered2(util.OrderedDict, collections.MappedCollection):
            def __init__(self, keyfunc):
                collections.MappedCollection.__init__(self, keyfunc)
                util.OrderedDict.__init__(self)

        def collection_class():
            return Ordered2(lambda v: (v.a, v.b))

        self._test_composite_mapped(collection_class)
1918
1919
class ColumnMappedWSerialize(fixtures.MappedTest):
    """Exercise column_mapped_collection keyfuncs against inherited
    (multi-table) and selectable-based mappings, including pickle
    round trips of the resulting collections."""

    run_create_tables = run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "foo",
            metadata,
            Column("id", Integer(), primary_key=True),
            Column("b", String(128)),
        )
        Table(
            "bar",
            metadata,
            Column("id", Integer(), primary_key=True),
            Column("foo_id", Integer, ForeignKey("foo.id")),
            Column("bat_id", Integer),
            schema="x",
        )

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

        class Bar(Foo):
            pass

    def test_indirect_table_column_mapped(self):
        Foo, Bar = self.classes.Foo, self.classes.Bar
        foo_table = self.tables.foo
        bar_table = self.tables["x.bar"]

        mapper(Foo, foo_table, properties={"foo_id": foo_table.c.id})
        mapper(
            Bar,
            bar_table,
            inherits=Foo,
            properties={"bar_id": bar_table.c.id},
        )

        sample = Bar(foo_id=1, bar_id=2, bat_id=3)
        self._run_test(
            [
                # base-mapper attribute evaluated on a subclass instance
                (Foo.foo_id, sample, 1),
                # composite key over subclass attributes
                ((Bar.bar_id, Bar.bat_id), sample, (2, 3)),
                # inherited attribute accessed through the subclass
                (Bar.foo_id, sample, 1),
                # raw Column of the subclass table
                (bar_table.c.id, sample, 2),
            ]
        )

    def test_selectable_column_mapped(self):
        from sqlalchemy import select

        Foo = self.classes.Foo
        aliased_foo = select([self.tables.foo]).alias()
        mapper(Foo, aliased_foo)
        self._run_test(
            [
                # mapped attribute of a selectable-based mapper
                (Foo.b, Foo(b=5), 5),
                # column taken directly from the aliased selectable
                (aliased_foo.c.b, Foo(b=5), 5),
            ]
        )

    def _run_test(self, specs):
        from sqlalchemy.testing.util import picklers

        for spec, obj, expected in specs:
            original = collections.column_mapped_collection(spec)()
            eq_(original.keyfunc(obj), expected)
            # pickle twice in a row with every available pickler, so that
            # __reduce__ output can itself be re-pickled and still key
            # correctly
            for loads, dumps in picklers():
                current = original
                for _ in range(2):
                    current = loads(dumps(current))
                    eq_(current.keyfunc(obj), expected)
1991
1992
class CustomCollectionsTest(fixtures.MappedTest):
    """Exercise user-supplied collection classes on mapped relationships.

    Each test maps a parent class to ``sometable`` with a one-to-many
    relationship to a child class mapped to ``someothertable``, varying the
    ``collection_class`` used for that relationship.
    """

    @classmethod
    def define_tables(cls, metadata):
        def autoincrement_pk():
            # build a new Column per call rather than sharing one instance
            # between the two tables
            return Column(
                "col1",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            )

        Table(
            "sometable",
            metadata,
            autoincrement_pk(),
            Column("data", String(30)),
        )
        Table(
            "someothertable",
            metadata,
            autoincrement_pk(),
            Column("scol1", Integer, ForeignKey("sometable.col1")),
            Column("data", String(20)),
        )

    def _map_parent_child(
        self, parent_cls, child_cls, attrname, collection_class
    ):
        """Map ``parent_cls`` to sometable and ``child_cls`` to
        someothertable, giving ``parent_cls`` a relationship named
        ``attrname`` to ``child_cls`` held in ``collection_class``."""
        mapper(
            parent_cls,
            self.tables.sometable,
            properties={
                attrname: relationship(
                    child_cls, collection_class=collection_class
                )
            },
        )
        mapper(child_cls, self.tables.someothertable)

    def _flush_and_reload(self, sess, cls, obj):
        """Flush ``sess``, detach everything from it, and return a freshly
        queried ``cls`` instance with the same ``col1`` as ``obj``."""
        sess.flush()
        sess.expunge_all()
        return sess.query(cls).get(obj.col1)

    def test_basic(self):
        """A list subclass given as collection_class backs new instances."""

        class MyList(list):
            pass

        class Foo(object):
            pass

        class Bar(object):
            pass

        self._map_parent_child(Foo, Bar, "bars", MyList)
        assert isinstance(Foo().bars, MyList)

    def test_lazyload(self):
        """A plain ``set`` works as the collection and can be lazy loaded."""

        class Foo(object):
            pass

        class Bar(object):
            pass

        self._map_parent_child(Foo, Bar, "bars", set)

        f = Foo()
        for _ in range(2):
            f.bars.add(Bar())

        sess = create_session()
        sess.add(f)
        f = self._flush_and_reload(sess, Foo, f)
        assert len(list(f.bars)) == 2
        f.bars.clear()

    def test_dict(self):
        """A dict subclass with explicit appender/remover methods works as
        the collection and can be lazy loaded."""

        class Foo(object):
            pass

        class Bar(object):
            pass

        class AppenderDict(dict):
            # members are keyed by their id()
            @collection.appender
            def set(self, item):
                self[id(item)] = item

            @collection.remover
            def remove(self, item):
                key = id(item)
                if key in self:
                    del self[key]

        self._map_parent_child(Foo, Bar, "bars", AppenderDict)

        f = Foo()
        for _ in range(2):
            f.bars.set(Bar())

        sess = create_session()
        sess.add(f)
        f = self._flush_and_reload(sess, Foo, f)
        assert len(list(f.bars)) == 2
        f.bars.clear()

    def test_dict_wrapper(self):
        """The collection produced by column_mapped_collection() can be lazy
        loaded, and re-assigning existing keys swaps in new members."""

        class Foo(object):
            pass

        class Bar(object):
            def __init__(self, data):
                self.data = data

        keyed_by_data = collections.column_mapped_collection(
            self.tables.someothertable.c.data
        )
        self._map_parent_child(Foo, Bar, "bars", keyed_by_data)

        f = Foo()
        adapter = collections.collection_adapter(f.bars)
        for key in ("a", "b"):
            adapter.append_with_event(Bar(key))

        sess = create_session()
        sess.add(f)
        f = self._flush_and_reload(sess, Foo, f)
        assert len(list(f.bars)) == 2

        # hold the loaded members so their id() values stay unique while
        # they are compared with the replacements below
        strongref = list(f.bars.values())
        existing = set(id(member) for member in strongref)

        adapter = collections.collection_adapter(f.bars)
        adapter.append_with_event(Bar("b"))
        f.bars["a"] = Bar("a")
        f = self._flush_and_reload(sess, Foo, f)
        assert len(list(f.bars)) == 2

        replaced = set(id(member) for member in f.bars.values())
        ne_(existing, replaced)

    def test_list(self):
        self._test_list(list)

    def test_list_no_setslice(self):
        """A duck-typed list without __setslice__/__delslice__ (and whose own
        extend() must never run) behaves like a list."""

        class ListLike(object):
            def __init__(self):
                self.data = []

            def append(self, item):
                self.data.append(item)

            def remove(self, item):
                self.data.remove(item)

            def insert(self, index, item):
                self.data.insert(index, item)

            def pop(self, index=-1):
                return self.data.pop(index)

            def extend(self):
                # deliberately unusable; reaching it fails the test
                assert False

            def __len__(self):
                return len(self.data)

            def __setitem__(self, key, value):
                self.data[key] = value

            def __getitem__(self, key):
                return self.data[key]

            def __delitem__(self, key):
                del self.data[key]

            def __iter__(self):
                return iter(self.data)

            # keep instances hashable even though __eq__ is defined
            __hash__ = object.__hash__

            def __eq__(self, other):
                return self.data == other

            def __repr__(self):
                return "ListLike(%r)" % (self.data,)

        self._test_list(ListLike)

    def _test_list(self, listcls):
        """Apply the same mutations to a plain list and to a mapped
        collection of type ``listcls``, checking after every step that the
        two remain equal."""

        class Parent(object):
            pass

        class Child(object):
            pass

        self._map_parent_child(Parent, Child, "children", listcls)

        control = []
        p = Parent()

        def verify():
            # compare against the live collection and a plain-list copy
            assert control == p.children
            assert control == list(p.children)

        def make_children(count):
            return [Child() for _ in range(count)]

        single = Child()
        control.append(single)
        p.children.append(single)
        verify()

        batch = make_children(4)
        control.extend(batch)
        p.children.extend(batch)
        verify()

        assert control[0] == p.children[0]
        assert control[-1] == p.children[-1]
        assert control[1:3] == p.children[1:3]

        del control[1]
        del p.children[1]
        verify()

        # slice syntax is kept literal on purpose: ListLike defines no
        # __setslice__/__delslice__
        batch = make_children(1)
        control[1:3] = batch
        p.children[1:3] = batch
        verify()

        batch = make_children(4)
        control[1:3] = batch
        p.children[1:3] = batch
        verify()

        batch = make_children(4)
        control[-1:-2] = batch
        p.children[-1:-2] = batch
        verify()

        batch = make_children(4)
        control[4:] = batch
        p.children[4:] = batch
        verify()

        for index in (0, 3, 999):
            single = Child()
            control.insert(index, single)
            p.children.insert(index, single)
            verify()

        del control[0:1]
        del p.children[0:1]
        verify()

        del control[1:1]
        del p.children[1:1]
        verify()

        del control[1:3]
        del p.children[1:3]
        verify()

        del control[7:]
        del p.children[7:]
        verify()

        # pop(), pop(0), pop(2)
        for pop_args in ((), (0,), (2,)):
            assert control.pop(*pop_args) == p.children.pop(*pop_args)
            verify()

        single = Child()
        control.insert(2, single)
        p.children.insert(2, single)
        verify()

        control.remove(single)
        p.children.remove(single)
        verify()

    def test_custom(self):
        """A plain class using the @collection role decorators works as the
        collection and can be lazy loaded."""

        class Parent(object):
            pass

        class Child(object):
            pass

        class MyCollection(object):
            def __init__(self):
                self.data = []

            @collection.appender
            def append(self, value):
                self.data.append(value)

            @collection.remover
            def remove(self, value):
                self.data.remove(value)

            @collection.iterator
            def __iter__(self):
                return iter(self.data)

        self._map_parent_child(Parent, Child, "children", MyCollection)

        control = []
        p1 = Parent()
        for _ in range(3):
            child = Child()
            control.append(child)
            p1.children.append(child)
            assert control == list(p1.children)

        sess = create_session()
        sess.add(p1)
        p2 = self._flush_and_reload(sess, Parent, p1)
        assert len(list(p2.children)) == 3
2417
2418
class InstrumentationTest(fixtures.ORMTest):
    """Tests for collection-class instrumentation and adapter linkage."""

    def test_uncooperative_descriptor_in_sweep(self):
        """_instrument_class completes on a class carrying a descriptor whose
        __get__ raises AttributeError."""

        class DoNotTouch(object):
            def __get__(self, obj, owner):
                raise AttributeError

        class Touchy(list):
            no_touch = DoNotTouch()

        # the name is in the namespace and in dir(), yet getattr() fails
        assert "no_touch" in Touchy.__dict__
        assert not hasattr(Touchy, "no_touch")
        assert "no_touch" in dir(Touchy)

        collections._instrument_class(Touchy)

    def test_name_setup(self):
        """Role decorators populate the _sa_<role> class attributes, and a
        subclass can override or add roles."""

        def check_roles(cls, expectations):
            # each _sa_<role> attribute is called as an unbound method
            for role, result in expectations:
                eq_(getattr(cls, "_sa_%s" % role)(cls(), 5), result)

        class Base(object):
            @collection.iterator
            def base_iterate(self, x):
                return "base_iterate"

            @collection.appender
            def base_append(self, x):
                return "base_append"

            @collection.converter
            def base_convert(self, x):
                return "base_convert"

            @collection.remover
            def base_remove(self, x):
                return "base_remove"

        collections._instrument_class(Base)

        check_roles(
            Base,
            [
                ("remover", "base_remove"),
                ("appender", "base_append"),
                ("iterator", "base_iterate"),
                ("converter", "base_convert"),
            ],
        )

        class Sub(Base):
            # same name as the Base method, so it overrides the converter
            @collection.converter
            def base_convert(self, x):
                return "sub_convert"

            @collection.remover
            def sub_remove(self, x):
                return "sub_remove"

        collections._instrument_class(Sub)

        check_roles(
            Sub,
            [
                ("appender", "base_append"),
                ("remover", "sub_remove"),
                ("iterator", "base_iterate"),
                ("converter", "sub_convert"),
            ],
        )

    @uses_deprecated(r".*Please refer to the .*init_collection")
    def test_link_event(self):
        """The @collection.linker hook sees each adapter as it is linked,
        then None when the collection is replaced."""
        linked = []

        class Collection(list):
            @collection.linker
            def _on_link(self, obj):
                linked.append(obj)

        class Foo(object):
            pass

        instrumentation.register_class(Foo)
        attributes.register_attribute(
            Foo, "attr", uselist=True, typecallable=Collection, useobject=True
        )

        f1 = Foo()
        f1.attr.append(3)

        first_adapter = f1.attr._sa_adapter
        eq_(linked, [first_adapter])

        replacement = Collection()
        f1.attr = replacement
        eq_(linked, [first_adapter, f1.attr._sa_adapter, None])

    def test_referenced_by_owner(self):
        """An adapter reports _referenced_by_owner until its collection is
        replaced on the owning instance."""

        class Foo(object):
            pass

        instrumentation.register_class(Foo)
        attributes.register_attribute(
            Foo, "attr", uselist=True, useobject=True
        )

        owner = Foo()
        owner.attr.append(3)

        adapter = collections.collection_adapter(owner.attr)
        assert adapter._referenced_by_owner

        owner.attr = []
        assert not adapter._referenced_by_owner
2521