1from sqlalchemy.testing import eq_
2import sys
3from operator import and_
4
5import sqlalchemy.orm.collections as collections
6from sqlalchemy.orm.collections import collection
7
8import sqlalchemy as sa
9from sqlalchemy import Integer, String, ForeignKey, text
10from sqlalchemy.testing.schema import Table, Column
11from sqlalchemy import util, exc as sa_exc
12from sqlalchemy.orm import create_session, mapper, relationship, \
13    attributes, instrumentation
14from sqlalchemy.testing import fixtures
15from sqlalchemy.testing import assert_raises, assert_raises_message
16
class Canary(sa.orm.interfaces.AttributeExtension):
    """Attribute extension that records collection membership events.

    ``data`` mirrors the membership the extension has been told about,
    ``added`` accumulates every value ever appended and ``removed`` every
    value ever removed.  The asserts enforce that a given value is appended
    at most once and removed at most once over the lifetime of a Canary.
    """

    def __init__(self):
        self.data, self.added, self.removed = set(), set(), set()

    def append(self, obj, value, initiator):
        """Record ``value`` as a new member and return it unchanged."""
        assert value not in self.added
        for bucket in (self.data, self.added):
            bucket.add(value)
        return value

    def remove(self, obj, value, initiator):
        """Record ``value`` as removed; ``KeyError`` if it is not in ``data``."""
        assert value not in self.removed
        self.data.remove(value)
        self.removed.add(value)

    def set(self, obj, value, oldvalue, initiator):
        """Record a replacement as remove(``oldvalue``) + append(``value``).

        A ``str`` value is first swapped for a freshly made entity; the
        (possibly swapped) value is what gets recorded and returned.
        """
        replacement = (CollectionsTest.entity_maker()
                       if isinstance(value, str) else value)
        if oldvalue is not None:
            self.remove(obj, oldvalue, None)
        return self.append(obj, replacement, None)
39
40class CollectionsTest(fixtures.ORMTest):
41    class Entity(object):
42        def __init__(self, a=None, b=None, c=None):
43            self.a = a
44            self.b = b
45            self.c = c
46        def __repr__(self):
47            return str((id(self), self.a, self.b, self.c))
48
49    @classmethod
50    def setup_class(cls):
51        instrumentation.register_class(cls.Entity)
52
53    @classmethod
54    def teardown_class(cls):
55        instrumentation.unregister_class(cls.Entity)
56        super(CollectionsTest, cls).teardown_class()
57
58    _entity_id = 1
59
60    @classmethod
61    def entity_maker(cls):
62        cls._entity_id += 1
63        return cls.Entity(cls._entity_id)
64
65    @classmethod
66    def dictable_entity(cls, a=None, b=None, c=None):
67        id = cls._entity_id = (cls._entity_id + 1)
68        return cls.Entity(a or str(id), b or 'value %s' % id, c)
69
70    def _test_adapter(self, typecallable, creator=None, to_set=None):
71        if creator is None:
72            creator = self.entity_maker
73
74        class Foo(object):
75            pass
76
77        canary = Canary()
78        instrumentation.register_class(Foo)
79        attributes.register_attribute(Foo, 'attr', uselist=True,
80                                    extension=canary,
81                                   typecallable=typecallable, useobject=True)
82
83        obj = Foo()
84        adapter = collections.collection_adapter(obj.attr)
85        direct = obj.attr
86        if to_set is None:
87            to_set = lambda col: set(col)
88
89        def assert_eq():
90            self.assert_(to_set(direct) == canary.data)
91            self.assert_(set(adapter) == canary.data)
92        assert_ne = lambda: self.assert_(to_set(direct) != canary.data)
93
94        e1, e2 = creator(), creator()
95
96        adapter.append_with_event(e1)
97        assert_eq()
98
99        adapter.append_without_event(e2)
100        assert_ne()
101        canary.data.add(e2)
102        assert_eq()
103
104        adapter.remove_without_event(e2)
105        assert_ne()
106        canary.data.remove(e2)
107        assert_eq()
108
109        adapter.remove_with_event(e1)
110        assert_eq()
111
112    def _test_list(self, typecallable, creator=None):
113        if creator is None:
114            creator = self.entity_maker
115
116        class Foo(object):
117            pass
118
119        canary = Canary()
120        instrumentation.register_class(Foo)
121        attributes.register_attribute(Foo, 'attr', uselist=True,
122                                    extension=canary,
123                                   typecallable=typecallable, useobject=True)
124
125        obj = Foo()
126        adapter = collections.collection_adapter(obj.attr)
127        direct = obj.attr
128        control = list()
129
130        def assert_eq():
131            eq_(set(direct), canary.data)
132            eq_(set(adapter), canary.data)
133            eq_(direct, control)
134
135        # assume append() is available for list tests
136        e = creator()
137        direct.append(e)
138        control.append(e)
139        assert_eq()
140
141        if hasattr(direct, 'pop'):
142            direct.pop()
143            control.pop()
144            assert_eq()
145
146        if hasattr(direct, '__setitem__'):
147            e = creator()
148            direct.append(e)
149            control.append(e)
150
151            e = creator()
152            direct[0] = e
153            control[0] = e
154            assert_eq()
155
156            if util.reduce(and_, [hasattr(direct, a) for a in
157                             ('__delitem__', 'insert', '__len__')], True):
158                values = [creator(), creator(), creator(), creator()]
159                direct[slice(0, 1)] = values
160                control[slice(0, 1)] = values
161                assert_eq()
162
163                values = [creator(), creator()]
164                direct[slice(0, -1, 2)] = values
165                control[slice(0, -1, 2)] = values
166                assert_eq()
167
168                values = [creator()]
169                direct[slice(0, -1)] = values
170                control[slice(0, -1)] = values
171                assert_eq()
172
173                values = [creator(), creator(), creator()]
174                control[:] = values
175                direct[:] = values
176                def invalid():
177                    direct[slice(0, 6, 2)] = [creator()]
178                assert_raises(ValueError, invalid)
179
180        if hasattr(direct, '__delitem__'):
181            e = creator()
182            direct.append(e)
183            control.append(e)
184            del direct[-1]
185            del control[-1]
186            assert_eq()
187
188            if hasattr(direct, '__getslice__'):
189                for e in [creator(), creator(), creator(), creator()]:
190                    direct.append(e)
191                    control.append(e)
192
193                del direct[:-3]
194                del control[:-3]
195                assert_eq()
196
197                del direct[0:1]
198                del control[0:1]
199                assert_eq()
200
201                del direct[::2]
202                del control[::2]
203                assert_eq()
204
205        if hasattr(direct, 'remove'):
206            e = creator()
207            direct.append(e)
208            control.append(e)
209
210            direct.remove(e)
211            control.remove(e)
212            assert_eq()
213
214        if hasattr(direct, '__setitem__') or hasattr(direct, '__setslice__'):
215
216            values = [creator(), creator()]
217            direct[:] = values
218            control[:] = values
219            assert_eq()
220
221            # test slice assignment where
222            # slice size goes over the number of items
223            values = [creator(), creator()]
224            direct[1:3] = values
225            control[1:3] = values
226            assert_eq()
227
228            values = [creator(), creator()]
229            direct[0:1] = values
230            control[0:1] = values
231            assert_eq()
232
233            values = [creator()]
234            direct[0:] = values
235            control[0:] = values
236            assert_eq()
237
238            values = [creator()]
239            direct[:1] = values
240            control[:1] = values
241            assert_eq()
242
243            values = [creator()]
244            direct[-1::2] = values
245            control[-1::2] = values
246            assert_eq()
247
248            values = [creator()] * len(direct[1::2])
249            direct[1::2] = values
250            control[1::2] = values
251            assert_eq()
252
253            values = [creator(), creator()]
254            direct[-1:-3] = values
255            control[-1:-3] = values
256            assert_eq()
257
258            values = [creator(), creator()]
259            direct[-2:-1] = values
260            control[-2:-1] = values
261            assert_eq()
262
263            values = [creator()]
264            direct[0:0] = values
265            control[0:0] = values
266            assert_eq()
267
268
269        if hasattr(direct, '__delitem__') or hasattr(direct, '__delslice__'):
270            for i in range(1, 4):
271                e = creator()
272                direct.append(e)
273                control.append(e)
274
275            del direct[-1:]
276            del control[-1:]
277            assert_eq()
278
279            del direct[1:2]
280            del control[1:2]
281            assert_eq()
282
283            del direct[:]
284            del control[:]
285            assert_eq()
286
287        if hasattr(direct, 'clear'):
288            for i in range(1, 4):
289                e = creator()
290                direct.append(e)
291                control.append(e)
292
293            direct.clear()
294            control.clear()
295            assert_eq()
296
297        if hasattr(direct, 'extend'):
298            values = [creator(), creator(), creator()]
299
300            direct.extend(values)
301            control.extend(values)
302            assert_eq()
303
304        if hasattr(direct, '__iadd__'):
305            values = [creator(), creator(), creator()]
306
307            direct += values
308            control += values
309            assert_eq()
310
311            direct += []
312            control += []
313            assert_eq()
314
315            values = [creator(), creator()]
316            obj.attr += values
317            control += values
318            assert_eq()
319
320        if hasattr(direct, '__imul__'):
321            direct *= 2
322            control *= 2
323            assert_eq()
324
325            obj.attr *= 2
326            control *= 2
327            assert_eq()
328
329    def _test_list_bulk(self, typecallable, creator=None):
330        if creator is None:
331            creator = self.entity_maker
332
333        class Foo(object):
334            pass
335
336        canary = Canary()
337        instrumentation.register_class(Foo)
338        attributes.register_attribute(Foo, 'attr', uselist=True,
339                                    extension=canary,
340                                   typecallable=typecallable, useobject=True)
341
342        obj = Foo()
343        direct = obj.attr
344
345        e1 = creator()
346        obj.attr.append(e1)
347
348        like_me = typecallable()
349        e2 = creator()
350        like_me.append(e2)
351
352        self.assert_(obj.attr is direct)
353        obj.attr = like_me
354        self.assert_(obj.attr is not direct)
355        self.assert_(obj.attr is not like_me)
356        self.assert_(set(obj.attr) == set([e2]))
357        self.assert_(e1 in canary.removed)
358        self.assert_(e2 in canary.added)
359
360        e3 = creator()
361        real_list = [e3]
362        obj.attr = real_list
363        self.assert_(obj.attr is not real_list)
364        self.assert_(set(obj.attr) == set([e3]))
365        self.assert_(e2 in canary.removed)
366        self.assert_(e3 in canary.added)
367
368        e4 = creator()
369        try:
370            obj.attr = set([e4])
371            self.assert_(False)
372        except TypeError:
373            self.assert_(e4 not in canary.data)
374            self.assert_(e3 in canary.data)
375
376        e5 = creator()
377        e6 = creator()
378        e7 = creator()
379        obj.attr = [e5, e6, e7]
380        self.assert_(e5 in canary.added)
381        self.assert_(e6 in canary.added)
382        self.assert_(e7 in canary.added)
383
384        obj.attr = [e6, e7]
385        self.assert_(e5 in canary.removed)
386        self.assert_(e6 in canary.added)
387        self.assert_(e7 in canary.added)
388        self.assert_(e6 not in canary.removed)
389        self.assert_(e7 not in canary.removed)
390
391    def test_list(self):
392        self._test_adapter(list)
393        self._test_list(list)
394        self._test_list_bulk(list)
395
396    def test_list_setitem_with_slices(self):
397
398        # this is a "list" that has no __setslice__
399        # or __delslice__ methods.  The __setitem__
400        # and __delitem__ must therefore accept
401        # slice objects (i.e. as in py3k)
402        class ListLike(object):
403            def __init__(self):
404                self.data = list()
405            def append(self, item):
406                self.data.append(item)
407            def remove(self, item):
408                self.data.remove(item)
409            def insert(self, index, item):
410                self.data.insert(index, item)
411            def pop(self, index=-1):
412                return self.data.pop(index)
413            def extend(self):
414                assert False
415            def __len__(self):
416                return len(self.data)
417            def __setitem__(self, key, value):
418                self.data[key] = value
419            def __getitem__(self, key):
420                return self.data[key]
421            def __delitem__(self, key):
422                del self.data[key]
423            def __iter__(self):
424                return iter(self.data)
425            __hash__ = object.__hash__
426            def __eq__(self, other):
427                return self.data == other
428            def __repr__(self):
429                return 'ListLike(%s)' % repr(self.data)
430
431        self._test_adapter(ListLike)
432        self._test_list(ListLike)
433        self._test_list_bulk(ListLike)
434
435    def test_list_subclass(self):
436        class MyList(list):
437            pass
438        self._test_adapter(MyList)
439        self._test_list(MyList)
440        self._test_list_bulk(MyList)
441        self.assert_(getattr(MyList, '_sa_instrumented') == id(MyList))
442
443    def test_list_duck(self):
444        class ListLike(object):
445            def __init__(self):
446                self.data = list()
447            def append(self, item):
448                self.data.append(item)
449            def remove(self, item):
450                self.data.remove(item)
451            def insert(self, index, item):
452                self.data.insert(index, item)
453            def pop(self, index=-1):
454                return self.data.pop(index)
455            def extend(self):
456                assert False
457            def __iter__(self):
458                return iter(self.data)
459            __hash__ = object.__hash__
460            def __eq__(self, other):
461                return self.data == other
462            def __repr__(self):
463                return 'ListLike(%s)' % repr(self.data)
464
465        self._test_adapter(ListLike)
466        self._test_list(ListLike)
467        self._test_list_bulk(ListLike)
468        self.assert_(getattr(ListLike, '_sa_instrumented') == id(ListLike))
469
470    def test_list_emulates(self):
471        class ListIsh(object):
472            __emulates__ = list
473            def __init__(self):
474                self.data = list()
475            def append(self, item):
476                self.data.append(item)
477            def remove(self, item):
478                self.data.remove(item)
479            def insert(self, index, item):
480                self.data.insert(index, item)
481            def pop(self, index=-1):
482                return self.data.pop(index)
483            def extend(self):
484                assert False
485            def __iter__(self):
486                return iter(self.data)
487            __hash__ = object.__hash__
488            def __eq__(self, other):
489                return self.data == other
490            def __repr__(self):
491                return 'ListIsh(%s)' % repr(self.data)
492
493        self._test_adapter(ListIsh)
494        self._test_list(ListIsh)
495        self._test_list_bulk(ListIsh)
496        self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh))
497
    def _test_set(self, typecallable, creator=None):
        """Drive a set-like collection through the set API against a control.

        Every mutation is applied both to the instrumented collection
        (``direct``) and to a plain ``set`` (``control``); ``assert_eq``
        then requires the collection to equal the control and the
        ``Canary`` extension's ``data`` to equal the collection's members.
        ``add``, ``clear`` and ``pop`` are called unconditionally; every
        other operation is only exercised when ``typecallable`` provides it.

        :param typecallable: collection class for the ``Foo.attr`` attribute.
        :param creator: zero-argument factory for member objects; defaults
          to ``self.entity_maker``.
        """
        if creator is None:
            creator = self.entity_maker

        class Foo(object):
            pass

        canary = Canary()
        instrumentation.register_class(Foo)
        attributes.register_attribute(Foo, 'attr', uselist=True,
                                    extension=canary,
                                   typecallable=typecallable, useobject=True)

        obj = Foo()
        adapter = collections.collection_adapter(obj.attr)
        direct = obj.attr
        control = set()

        def assert_eq():
            eq_(set(direct), canary.data)
            eq_(set(adapter), canary.data)
            eq_(direct, control)

        # add each value to both sides, then compare
        def addall(*values):
            for item in values:
                direct.add(item)
                control.add(item)
            assert_eq()
        # empty both sides; the collection is emptied member by member
        # through its ``remove`` method
        def zap():
            for item in list(direct):
                direct.remove(item)
            control.clear()

        addall(creator())

        # re-adding an existing member must not record a second append
        # (Canary.append asserts a value is only ever added once)
        e = creator()
        addall(e)
        addall(e)


        if hasattr(direct, 'remove'):
            e = creator()
            addall(e)

            direct.remove(e)
            control.remove(e)
            assert_eq()

            # removing a non-member must raise KeyError and record nothing
            e = creator()
            try:
                direct.remove(e)
            except KeyError:
                assert_eq()
                self.assert_(e not in canary.removed)
            else:
                self.assert_(False)

        if hasattr(direct, 'discard'):
            e = creator()
            addall(e)

            direct.discard(e)
            control.discard(e)
            assert_eq()

            # discarding a non-member is silent and records nothing
            e = creator()
            direct.discard(e)
            self.assert_(e not in canary.removed)
            assert_eq()

        if hasattr(direct, 'update'):
            zap()
            e = creator()
            addall(e)

            # values overlaps the current contents (e) on purpose
            values = set([e, creator(), creator()])

            direct.update(values)
            control.update(values)
            assert_eq()

        if hasattr(direct, '__ior__'):
            zap()
            e = creator()
            addall(e)

            values = set([e, creator(), creator()])

            direct |= values
            control |= values
            assert_eq()

            # cover self-assignment short-circuit
            values = set([e, creator(), creator()])
            obj.attr |= values
            control |= values
            assert_eq()

            values = frozenset([e, creator()])
            obj.attr |= values
            control |= values
            assert_eq()

            # a list operand is not accepted by the in-place operator
            try:
                direct |= [e, creator()]
                assert False
            except TypeError:
                assert True

        addall(creator(), creator())
        direct.clear()
        control.clear()
        assert_eq()

        # note: the clear test previously needs
        # to have executed in order for this to
        # pass in all cases; else there's the possibility
        # of non-deterministic behavior.
        addall(creator())
        direct.pop()
        control.pop()
        assert_eq()

        if hasattr(direct, 'difference_update'):
            zap()
            # e is never added, so the second round also subtracts a
            # non-member
            e = creator()
            addall(creator(), creator())
            values = set([creator()])

            direct.difference_update(values)
            control.difference_update(values)
            assert_eq()
            values.update(set([e, creator()]))
            direct.difference_update(values)
            control.difference_update(values)
            assert_eq()

        if hasattr(direct, '__isub__'):
            zap()
            # as above, e is never added to the collection
            e = creator()
            addall(creator(), creator())
            values = set([creator()])

            direct -= values
            control -= values
            assert_eq()
            values.update(set([e, creator()]))
            direct -= values
            control -= values
            assert_eq()

            values = set([creator()])
            obj.attr -= values
            control -= values
            assert_eq()

            values = frozenset([creator()])
            obj.attr -= values
            control -= values
            assert_eq()

            # a list operand is not accepted by the in-place operator
            try:
                direct -= [e, creator()]
                assert False
            except TypeError:
                assert True

        if hasattr(direct, 'intersection_update'):
            zap()
            e = creator()
            addall(e, creator(), creator())
            values = set(control)

            # intersect with an identical set: nothing should change
            direct.intersection_update(values)
            control.intersection_update(values)
            assert_eq()

            values.update(set([e, creator()]))
            direct.intersection_update(values)
            control.intersection_update(values)
            assert_eq()

        if hasattr(direct, '__iand__'):
            zap()
            e = creator()
            addall(e, creator(), creator())
            values = set(control)

            direct &= values
            control &= values
            assert_eq()

            values.update(set([e, creator()]))
            direct &= values
            control &= values
            assert_eq()

            values.update(set([creator()]))
            obj.attr &= values
            control &= values
            assert_eq()

            # a list operand is not accepted by the in-place operator
            try:
                direct &= [e, creator()]
                assert False
            except TypeError:
                assert True

        if hasattr(direct, 'symmetric_difference_update'):
            zap()
            e = creator()
            addall(e, creator(), creator())

            # e is removed, the fresh entity is added
            values = set([e, creator()])
            direct.symmetric_difference_update(values)
            control.symmetric_difference_update(values)
            assert_eq()

            e = creator()
            addall(e)
            values = set([e])
            direct.symmetric_difference_update(values)
            control.symmetric_difference_update(values)
            assert_eq()

            values = set()
            direct.symmetric_difference_update(values)
            control.symmetric_difference_update(values)
            assert_eq()

        if hasattr(direct, '__ixor__'):
            zap()
            e = creator()
            addall(e, creator(), creator())

            values = set([e, creator()])
            direct ^= values
            control ^= values
            assert_eq()

            e = creator()
            addall(e)
            values = set([e])
            direct ^= values
            control ^= values
            assert_eq()

            values = set()
            direct ^= values
            control ^= values
            assert_eq()

            values = set([creator()])
            obj.attr ^= values
            control ^= values
            assert_eq()

            # a list operand is not accepted by the in-place operator
            try:
                direct ^= [e, creator()]
                assert False
            except TypeError:
                assert True
760
761
762    def _test_set_bulk(self, typecallable, creator=None):
763        if creator is None:
764            creator = self.entity_maker
765
766        class Foo(object):
767            pass
768
769        canary = Canary()
770        instrumentation.register_class(Foo)
771        attributes.register_attribute(Foo, 'attr', uselist=True,
772                                    extension=canary,
773                                   typecallable=typecallable, useobject=True)
774
775        obj = Foo()
776        direct = obj.attr
777
778        e1 = creator()
779        obj.attr.add(e1)
780
781        like_me = typecallable()
782        e2 = creator()
783        like_me.add(e2)
784
785        self.assert_(obj.attr is direct)
786        obj.attr = like_me
787        self.assert_(obj.attr is not direct)
788        self.assert_(obj.attr is not like_me)
789        self.assert_(obj.attr == set([e2]))
790        self.assert_(e1 in canary.removed)
791        self.assert_(e2 in canary.added)
792
793        e3 = creator()
794        real_set = set([e3])
795        obj.attr = real_set
796        self.assert_(obj.attr is not real_set)
797        self.assert_(obj.attr == set([e3]))
798        self.assert_(e2 in canary.removed)
799        self.assert_(e3 in canary.added)
800
801        e4 = creator()
802        try:
803            obj.attr = [e4]
804            self.assert_(False)
805        except TypeError:
806            self.assert_(e4 not in canary.data)
807            self.assert_(e3 in canary.data)
808
809    def test_set(self):
810        self._test_adapter(set)
811        self._test_set(set)
812        self._test_set_bulk(set)
813
814    def test_set_subclass(self):
815        class MySet(set):
816            pass
817        self._test_adapter(MySet)
818        self._test_set(MySet)
819        self._test_set_bulk(MySet)
820        self.assert_(getattr(MySet, '_sa_instrumented') == id(MySet))
821
822    def test_set_duck(self):
823        class SetLike(object):
824            def __init__(self):
825                self.data = set()
826            def add(self, item):
827                self.data.add(item)
828            def remove(self, item):
829                self.data.remove(item)
830            def discard(self, item):
831                self.data.discard(item)
832            def clear(self):
833                self.data.clear()
834            def pop(self):
835                return self.data.pop()
836            def update(self, other):
837                self.data.update(other)
838            def __iter__(self):
839                return iter(self.data)
840            __hash__ = object.__hash__
841            def __eq__(self, other):
842                return self.data == other
843
844        self._test_adapter(SetLike)
845        self._test_set(SetLike)
846        self._test_set_bulk(SetLike)
847        self.assert_(getattr(SetLike, '_sa_instrumented') == id(SetLike))
848
849    def test_set_emulates(self):
850        class SetIsh(object):
851            __emulates__ = set
852            def __init__(self):
853                self.data = set()
854            def add(self, item):
855                self.data.add(item)
856            def remove(self, item):
857                self.data.remove(item)
858            def discard(self, item):
859                self.data.discard(item)
860            def pop(self):
861                return self.data.pop()
862            def update(self, other):
863                self.data.update(other)
864            def __iter__(self):
865                return iter(self.data)
866            def clear(self):
867                self.data.clear()
868            __hash__ = object.__hash__
869            def __eq__(self, other):
870                return self.data == other
871
872        self._test_adapter(SetIsh)
873        self._test_set(SetIsh)
874        self._test_set_bulk(SetIsh)
875        self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh))
876
877    def _test_dict(self, typecallable, creator=None):
878        if creator is None:
879            creator = self.dictable_entity
880
881        class Foo(object):
882            pass
883
884        canary = Canary()
885        instrumentation.register_class(Foo)
886        attributes.register_attribute(Foo, 'attr', uselist=True,
887                                    extension=canary,
888                                   typecallable=typecallable, useobject=True)
889
890        obj = Foo()
891        adapter = collections.collection_adapter(obj.attr)
892        direct = obj.attr
893        control = dict()
894
895        def assert_eq():
896            self.assert_(set(direct.values()) == canary.data)
897            self.assert_(set(adapter) == canary.data)
898            self.assert_(direct == control)
899
900        def addall(*values):
901            for item in values:
902                direct.set(item)
903                control[item.a] = item
904            assert_eq()
905        def zap():
906            for item in list(adapter):
907                direct.remove(item)
908            control.clear()
909
910        # assume an 'set' method is available for tests
911        addall(creator())
912
913        if hasattr(direct, '__setitem__'):
914            e = creator()
915            direct[e.a] = e
916            control[e.a] = e
917            assert_eq()
918
919            e = creator(e.a, e.b)
920            direct[e.a] = e
921            control[e.a] = e
922            assert_eq()
923
924        if hasattr(direct, '__delitem__'):
925            e = creator()
926            addall(e)
927
928            del direct[e.a]
929            del control[e.a]
930            assert_eq()
931
932            e = creator()
933            try:
934                del direct[e.a]
935            except KeyError:
936                self.assert_(e not in canary.removed)
937
938        if hasattr(direct, 'clear'):
939            addall(creator(), creator(), creator())
940
941            direct.clear()
942            control.clear()
943            assert_eq()
944
945            direct.clear()
946            control.clear()
947            assert_eq()
948
949        if hasattr(direct, 'pop'):
950            e = creator()
951            addall(e)
952
953            direct.pop(e.a)
954            control.pop(e.a)
955            assert_eq()
956
957            e = creator()
958            try:
959                direct.pop(e.a)
960            except KeyError:
961                self.assert_(e not in canary.removed)
962
963        if hasattr(direct, 'popitem'):
964            zap()
965            e = creator()
966            addall(e)
967
968            direct.popitem()
969            control.popitem()
970            assert_eq()
971
972        if hasattr(direct, 'setdefault'):
973            e = creator()
974
975            val_a = direct.setdefault(e.a, e)
976            val_b = control.setdefault(e.a, e)
977            assert_eq()
978            self.assert_(val_a is val_b)
979
980            val_a = direct.setdefault(e.a, e)
981            val_b = control.setdefault(e.a, e)
982            assert_eq()
983            self.assert_(val_a is val_b)
984
985        if hasattr(direct, 'update'):
986            e = creator()
987            d = dict([(ee.a, ee) for ee in [e, creator(), creator()]])
988            addall(e, creator())
989
990            direct.update(d)
991            control.update(d)
992            assert_eq()
993
994            kw = dict([(ee.a, ee) for ee in [e, creator()]])
995            direct.update(**kw)
996            control.update(**kw)
997            assert_eq()
998
999    def _test_dict_bulk(self, typecallable, creator=None):
1000        if creator is None:
1001            creator = self.dictable_entity
1002
1003        class Foo(object):
1004            pass
1005
1006        canary = Canary()
1007        instrumentation.register_class(Foo)
1008        attributes.register_attribute(Foo, 'attr', uselist=True,
1009                                    extension=canary,
1010                                   typecallable=typecallable, useobject=True)
1011
1012        obj = Foo()
1013        direct = obj.attr
1014
1015        e1 = creator()
1016        collections.collection_adapter(direct).append_with_event(e1)
1017
1018        like_me = typecallable()
1019        e2 = creator()
1020        like_me.set(e2)
1021
1022        self.assert_(obj.attr is direct)
1023        obj.attr = like_me
1024        self.assert_(obj.attr is not direct)
1025        self.assert_(obj.attr is not like_me)
1026        self.assert_(
1027                set(collections.collection_adapter(obj.attr)) == set([e2]))
1028        self.assert_(e1 in canary.removed)
1029        self.assert_(e2 in canary.added)
1030
1031
1032        # key validity on bulk assignment is a basic feature of
1033        # MappedCollection but is not present in basic, @converter-less
1034        # dict collections.
1035        e3 = creator()
1036        if isinstance(obj.attr, collections.MappedCollection):
1037            real_dict = dict(badkey=e3)
1038            try:
1039                obj.attr = real_dict
1040                self.assert_(False)
1041            except TypeError:
1042                pass
1043            self.assert_(obj.attr is not real_dict)
1044            self.assert_('badkey' not in obj.attr)
1045            eq_(set(collections.collection_adapter(obj.attr)),
1046                              set([e2]))
1047            self.assert_(e3 not in canary.added)
1048        else:
1049            real_dict = dict(keyignored1=e3)
1050            obj.attr = real_dict
1051            self.assert_(obj.attr is not real_dict)
1052            self.assert_('keyignored1' not in obj.attr)
1053            eq_(set(collections.collection_adapter(obj.attr)),
1054                              set([e3]))
1055            self.assert_(e2 in canary.removed)
1056            self.assert_(e3 in canary.added)
1057
1058        obj.attr = typecallable()
1059        eq_(list(collections.collection_adapter(obj.attr)), [])
1060
1061        e4 = creator()
1062        try:
1063            obj.attr = [e4]
1064            self.assert_(False)
1065        except TypeError:
1066            self.assert_(e4 not in canary.data)
1067
1068    def test_dict(self):
1069        assert_raises_message(
1070            sa_exc.ArgumentError,
1071            'Type InstrumentedDict must elect an appender '
1072                'method to be a collection class',
1073            self._test_adapter, dict, self.dictable_entity,
1074                               to_set=lambda c: set(c.values())
1075        )
1076
1077        assert_raises_message(
1078            sa_exc.ArgumentError,
1079            'Type InstrumentedDict must elect an appender method '
1080                'to be a collection class',
1081            self._test_dict, dict
1082        )
1083
1084    def test_dict_subclass(self):
1085        class MyDict(dict):
1086            @collection.appender
1087            @collection.internally_instrumented
1088            def set(self, item, _sa_initiator=None):
1089                self.__setitem__(item.a, item, _sa_initiator=_sa_initiator)
1090            @collection.remover
1091            @collection.internally_instrumented
1092            def _remove(self, item, _sa_initiator=None):
1093                self.__delitem__(item.a, _sa_initiator=_sa_initiator)
1094
1095        self._test_adapter(MyDict, self.dictable_entity,
1096                           to_set=lambda c: set(c.values()))
1097        self._test_dict(MyDict)
1098        self._test_dict_bulk(MyDict)
1099        self.assert_(getattr(MyDict, '_sa_instrumented') == id(MyDict))
1100
1101    def test_dict_subclass2(self):
1102        class MyEasyDict(collections.MappedCollection):
1103            def __init__(self):
1104                super(MyEasyDict, self).__init__(lambda e: e.a)
1105
1106        self._test_adapter(MyEasyDict, self.dictable_entity,
1107                           to_set=lambda c: set(c.values()))
1108        self._test_dict(MyEasyDict)
1109        self._test_dict_bulk(MyEasyDict)
1110        self.assert_(getattr(MyEasyDict, '_sa_instrumented') == id(MyEasyDict))
1111
1112    def test_dict_subclass3(self):
1113        class MyOrdered(util.OrderedDict, collections.MappedCollection):
1114            def __init__(self):
1115                collections.MappedCollection.__init__(self, lambda e: e.a)
1116                util.OrderedDict.__init__(self)
1117
1118        self._test_adapter(MyOrdered, self.dictable_entity,
1119                           to_set=lambda c: set(c.values()))
1120        self._test_dict(MyOrdered)
1121        self._test_dict_bulk(MyOrdered)
1122        self.assert_(getattr(MyOrdered, '_sa_instrumented') == id(MyOrdered))
1123
1124    def test_dict_subclass4(self):
1125        # tests #2654
1126        class MyDict(collections.MappedCollection):
1127            def __init__(self):
1128                super(MyDict, self).__init__(lambda value: "k%d" % value)
1129
1130            @collection.converter
1131            def _convert(self, dictlike):
1132                for key, value in dictlike.items():
1133                    yield value + 5
1134
1135        class Foo(object):
1136            pass
1137
1138        canary = Canary()
1139
1140        instrumentation.register_class(Foo)
1141        attributes.register_attribute(Foo, 'attr', uselist=True,
1142                                    extension=canary,
1143                                   typecallable=MyDict, useobject=True)
1144
1145        f = Foo()
1146        f.attr = {"k1": 1, "k2": 2}
1147
1148        eq_(f.attr, {'k7': 7, 'k6': 6})
1149
1150    def test_dict_duck(self):
1151        class DictLike(object):
1152            def __init__(self):
1153                self.data = dict()
1154
1155            @collection.appender
1156            @collection.replaces(1)
1157            def set(self, item):
1158                current = self.data.get(item.a, None)
1159                self.data[item.a] = item
1160                return current
1161            @collection.remover
1162            def _remove(self, item):
1163                del self.data[item.a]
1164            def __setitem__(self, key, value):
1165                self.data[key] = value
1166            def __getitem__(self, key):
1167                return self.data[key]
1168            def __delitem__(self, key):
1169                del self.data[key]
1170            def values(self):
1171                return list(self.data.values())
1172            def __contains__(self, key):
1173                return key in self.data
1174            @collection.iterator
1175            def itervalues(self):
1176                return iter(self.data.values())
1177            __hash__ = object.__hash__
1178            def __eq__(self, other):
1179                return self.data == other
1180            def __repr__(self):
1181                return 'DictLike(%s)' % repr(self.data)
1182
1183        self._test_adapter(DictLike, self.dictable_entity,
1184                           to_set=lambda c: set(c.values()))
1185        self._test_dict(DictLike)
1186        self._test_dict_bulk(DictLike)
1187        self.assert_(getattr(DictLike, '_sa_instrumented') == id(DictLike))
1188
1189    def test_dict_emulates(self):
1190        class DictIsh(object):
1191            __emulates__ = dict
1192            def __init__(self):
1193                self.data = dict()
1194
1195            @collection.appender
1196            @collection.replaces(1)
1197            def set(self, item):
1198                current = self.data.get(item.a, None)
1199                self.data[item.a] = item
1200                return current
1201            @collection.remover
1202            def _remove(self, item):
1203                del self.data[item.a]
1204            def __setitem__(self, key, value):
1205                self.data[key] = value
1206            def __getitem__(self, key):
1207                return self.data[key]
1208            def __delitem__(self, key):
1209                del self.data[key]
1210            def values(self):
1211                return list(self.data.values())
1212            def __contains__(self, key):
1213                return key in self.data
1214            @collection.iterator
1215            def itervalues(self):
1216                return iter(self.data.values())
1217            __hash__ = object.__hash__
1218            def __eq__(self, other):
1219                return self.data == other
1220            def __repr__(self):
1221                return 'DictIsh(%s)' % repr(self.data)
1222
1223        self._test_adapter(DictIsh, self.dictable_entity,
1224                           to_set=lambda c: set(c.values()))
1225        self._test_dict(DictIsh)
1226        self._test_dict_bulk(DictIsh)
1227        self.assert_(getattr(DictIsh, '_sa_instrumented') == id(DictIsh))
1228
1229    def _test_object(self, typecallable, creator=None):
1230        if creator is None:
1231            creator = self.entity_maker
1232
1233        class Foo(object):
1234            pass
1235
1236        canary = Canary()
1237        instrumentation.register_class(Foo)
1238        attributes.register_attribute(Foo, 'attr', uselist=True,
1239                                    extension=canary,
1240                                   typecallable=typecallable, useobject=True)
1241
1242        obj = Foo()
1243        adapter = collections.collection_adapter(obj.attr)
1244        direct = obj.attr
1245        control = set()
1246
1247        def assert_eq():
1248            self.assert_(set(direct) == canary.data)
1249            self.assert_(set(adapter) == canary.data)
1250            self.assert_(direct == control)
1251
1252        # There is no API for object collections.  We'll make one up
1253        # for the purposes of the test.
1254        e = creator()
1255        direct.push(e)
1256        control.add(e)
1257        assert_eq()
1258
1259        direct.zark(e)
1260        control.remove(e)
1261        assert_eq()
1262
1263        e = creator()
1264        direct.maybe_zark(e)
1265        control.discard(e)
1266        assert_eq()
1267
1268        e = creator()
1269        direct.push(e)
1270        control.add(e)
1271        assert_eq()
1272
1273        e = creator()
1274        direct.maybe_zark(e)
1275        control.discard(e)
1276        assert_eq()
1277
1278    def test_object_duck(self):
1279        class MyCollection(object):
1280            def __init__(self):
1281                self.data = set()
1282            @collection.appender
1283            def push(self, item):
1284                self.data.add(item)
1285            @collection.remover
1286            def zark(self, item):
1287                self.data.remove(item)
1288            @collection.removes_return()
1289            def maybe_zark(self, item):
1290                if item in self.data:
1291                    self.data.remove(item)
1292                    return item
1293            @collection.iterator
1294            def __iter__(self):
1295                return iter(self.data)
1296            __hash__ = object.__hash__
1297            def __eq__(self, other):
1298                return self.data == other
1299
1300        self._test_adapter(MyCollection)
1301        self._test_object(MyCollection)
1302        self.assert_(getattr(MyCollection, '_sa_instrumented') ==
1303                     id(MyCollection))
1304
1305    def test_object_emulates(self):
1306        class MyCollection2(object):
1307            __emulates__ = None
1308            def __init__(self):
1309                self.data = set()
1310            # looks like a list
1311            def append(self, item):
1312                assert False
1313            @collection.appender
1314            def push(self, item):
1315                self.data.add(item)
1316            @collection.remover
1317            def zark(self, item):
1318                self.data.remove(item)
1319            @collection.removes_return()
1320            def maybe_zark(self, item):
1321                if item in self.data:
1322                    self.data.remove(item)
1323                    return item
1324            @collection.iterator
1325            def __iter__(self):
1326                return iter(self.data)
1327            __hash__ = object.__hash__
1328            def __eq__(self, other):
1329                return self.data == other
1330
1331        self._test_adapter(MyCollection2)
1332        self._test_object(MyCollection2)
1333        self.assert_(getattr(MyCollection2, '_sa_instrumented') ==
1334                     id(MyCollection2))
1335
1336    def test_recipes(self):
1337        class Custom(object):
1338            def __init__(self):
1339                self.data = []
1340            @collection.appender
1341            @collection.adds('entity')
1342            def put(self, entity):
1343                self.data.append(entity)
1344
1345            @collection.remover
1346            @collection.removes(1)
1347            def remove(self, entity):
1348                self.data.remove(entity)
1349
1350            @collection.adds(1)
1351            def push(self, *args):
1352                self.data.append(args[0])
1353
1354            @collection.removes('entity')
1355            def yank(self, entity, arg):
1356                self.data.remove(entity)
1357
1358            @collection.replaces(2)
1359            def replace(self, arg, entity, **kw):
1360                self.data.insert(0, entity)
1361                return self.data.pop()
1362
1363            @collection.removes_return()
1364            def pop(self, key):
1365                return self.data.pop()
1366
1367            @collection.iterator
1368            def __iter__(self):
1369                return iter(self.data)
1370
1371        class Foo(object):
1372            pass
1373        canary = Canary()
1374        instrumentation.register_class(Foo)
1375        attributes.register_attribute(Foo, 'attr', uselist=True,
1376                                    extension=canary,
1377                                   typecallable=Custom, useobject=True)
1378
1379        obj = Foo()
1380        adapter = collections.collection_adapter(obj.attr)
1381        direct = obj.attr
1382        control = list()
1383        def assert_eq():
1384            self.assert_(set(direct) == canary.data)
1385            self.assert_(set(adapter) == canary.data)
1386            self.assert_(list(direct) == control)
1387        creator = self.entity_maker
1388
1389        e1 = creator()
1390        direct.put(e1)
1391        control.append(e1)
1392        assert_eq()
1393
1394        e2 = creator()
1395        direct.put(entity=e2)
1396        control.append(e2)
1397        assert_eq()
1398
1399        direct.remove(e2)
1400        control.remove(e2)
1401        assert_eq()
1402
1403        direct.remove(entity=e1)
1404        control.remove(e1)
1405        assert_eq()
1406
1407        e3 = creator()
1408        direct.push(e3)
1409        control.append(e3)
1410        assert_eq()
1411
1412        direct.yank(e3, 'blah')
1413        control.remove(e3)
1414        assert_eq()
1415
1416        e4, e5, e6, e7 = creator(), creator(), creator(), creator()
1417        direct.put(e4)
1418        direct.put(e5)
1419        control.append(e4)
1420        control.append(e5)
1421
1422        dr1 = direct.replace('foo', e6, bar='baz')
1423        control.insert(0, e6)
1424        cr1 = control.pop()
1425        assert_eq()
1426        self.assert_(dr1 is cr1)
1427
1428        dr2 = direct.replace(arg=1, entity=e7)
1429        control.insert(0, e7)
1430        cr2 = control.pop()
1431        assert_eq()
1432        self.assert_(dr2 is cr2)
1433
1434        dr3 = direct.pop('blah')
1435        cr3 = control.pop()
1436        assert_eq()
1437        self.assert_(dr3 is cr3)
1438
1439    def test_lifecycle(self):
1440        class Foo(object):
1441            pass
1442
1443        canary = Canary()
1444        creator = self.entity_maker
1445        instrumentation.register_class(Foo)
1446        attributes.register_attribute(Foo, 'attr', uselist=True,
1447                                extension=canary, useobject=True)
1448
1449        obj = Foo()
1450        col1 = obj.attr
1451
1452        e1 = creator()
1453        obj.attr.append(e1)
1454
1455        e2 = creator()
1456        bulk1 = [e2]
1457        # empty & sever col1 from obj
1458        obj.attr = bulk1
1459        self.assert_(len(col1) == 0)
1460        self.assert_(len(canary.data) == 1)
1461        self.assert_(obj.attr is not col1)
1462        self.assert_(obj.attr is not bulk1)
1463        self.assert_(obj.attr == bulk1)
1464
1465        e3 = creator()
1466        col1.append(e3)
1467        self.assert_(e3 not in canary.data)
1468        self.assert_(collections.collection_adapter(col1) is None)
1469
1470        obj.attr[0] = e3
1471        self.assert_(e3 in canary.data)
1472
class DictHelpersTest(fixtures.MappedTest):
    """Round-trip dict-based collection classes (mapped_collection,
    attribute_mapped_collection, column_mapped_collection and
    MappedCollection mixins) through a session."""

    @classmethod
    def define_tables(cls, metadata):
        Table('parents', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('label', String(128)))
        Table('children', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('parent_id', Integer, ForeignKey('parents.id'),
                     nullable=False),
              Column('a', String(128)),
              Column('b', String(128)),
              Column('c', String(128)))

    @classmethod
    def setup_classes(cls):
        class Parent(cls.Basic):
            def __init__(self, label=None):
                self.label = label

        class Child(cls.Basic):
            def __init__(self, a=None, b=None, c=None):
                self.a = a
                self.b = b
                self.c = c

    @staticmethod
    def _adapted_len(coll):
        """Return how many members the collection's adapter iterates."""
        return len(list(collections.collection_adapter(coll)))

    def _test_scalar_mapped(self, collection_class):
        """Persist and mutate ``Parent.children`` keyed by the scalar
        ``Child.a``, using ``collection_class`` as the collection type."""
        parents, children, Parent, Child = (self.tables.parents,
                                            self.tables.children,
                                            self.classes.Parent,
                                            self.classes.Child)

        mapper(Child, children)
        mapper(Parent, parents, properties={
            'children': relationship(Child, collection_class=collection_class,
                                     cascade="all, delete-orphan")})

        p = Parent()
        p.children['foo'] = Child('foo', 'value')
        p.children['bar'] = Child('bar', 'value')
        session = create_session()
        session.add(p)
        session.flush()
        pid = p.id
        session.expunge_all()

        p = session.query(Parent).get(pid)

        eq_(set(p.children.keys()), set(['foo', 'bar']))
        cid = p.children['foo'].id

        # appending a Child whose key is already present replaces the
        # existing 'foo' entry with a new row
        collections.collection_adapter(p.children).append_with_event(
            Child('foo', 'newvalue'))

        session.flush()
        session.expunge_all()

        p = session.query(Parent).get(pid)

        eq_(set(p.children.keys()), set(['foo', 'bar']))
        self.assert_(p.children['foo'].id != cid)

        eq_(self._adapted_len(p.children), 2)
        session.flush()
        session.expunge_all()

        p = session.query(Parent).get(pid)
        eq_(self._adapted_len(p.children), 2)

        collections.collection_adapter(p.children).remove_with_event(
            p.children['foo'])

        eq_(self._adapted_len(p.children), 1)
        session.flush()
        session.expunge_all()

        p = session.query(Parent).get(pid)
        eq_(self._adapted_len(p.children), 1)

        del p.children['bar']
        eq_(self._adapted_len(p.children), 0)
        session.flush()
        session.expunge_all()

        p = session.query(Parent).get(pid)
        eq_(self._adapted_len(p.children), 0)

    def _test_composite_mapped(self, collection_class):
        """Persist and replace members of ``Parent.children`` keyed by the
        tuple ``(Child.a, Child.b)``."""
        parents, children, Parent, Child = (self.tables.parents,
                                            self.tables.children,
                                            self.classes.Parent,
                                            self.classes.Child)

        mapper(Child, children)
        mapper(Parent, parents, properties={
            'children': relationship(Child, collection_class=collection_class,
                                     cascade="all, delete-orphan")
        })

        p = Parent()
        p.children[('foo', '1')] = Child('foo', '1', 'value 1')
        p.children[('foo', '2')] = Child('foo', '2', 'value 2')

        session = create_session()
        session.add(p)
        session.flush()
        pid = p.id
        session.expunge_all()

        p = session.query(Parent).get(pid)

        eq_(set(p.children.keys()), set([('foo', '1'), ('foo', '2')]))
        cid = p.children[('foo', '1')].id

        # appending a Child with an existing composite key replaces it
        collections.collection_adapter(p.children).append_with_event(
            Child('foo', '1', 'newvalue'))

        session.flush()
        session.expunge_all()

        p = session.query(Parent).get(pid)

        eq_(set(p.children.keys()), set([('foo', '1'), ('foo', '2')]))
        self.assert_(p.children[('foo', '1')].id != cid)

        eq_(self._adapted_len(p.children), 2)

    def test_mapped_collection(self):
        collection_class = collections.mapped_collection(lambda c: c.a)
        self._test_scalar_mapped(collection_class)

    def test_mapped_collection2(self):
        collection_class = collections.mapped_collection(lambda c: (c.a, c.b))
        self._test_composite_mapped(collection_class)

    def test_attr_mapped_collection(self):
        collection_class = collections.attribute_mapped_collection('a')
        self._test_scalar_mapped(collection_class)

    def test_declarative_column_mapped(self):
        """test that uncompiled attribute usage works with
        column_mapped_collection"""

        from sqlalchemy.ext.declarative import declarative_base

        BaseObject = declarative_base()

        class Foo(BaseObject):
            __tablename__ = "foo"
            id = Column(Integer(), primary_key=True)
            bar_id = Column(Integer, ForeignKey('bar.id'))

        for spec, obj, expected in (
            (Foo.id, Foo(id=3), 3),
            ((Foo.id, Foo.bar_id), Foo(id=3, bar_id=12), (3, 12))
        ):
            eq_(
                collections.column_mapped_collection(spec)().keyfunc(obj),
                expected
            )

    def test_column_mapped_assertions(self):
        assert_raises_message(sa_exc.ArgumentError,
                              "Column-based expression object expected "
                              "for argument 'mapping_spec'; got: 'a'",
                              collections.column_mapped_collection, 'a')
        assert_raises_message(sa_exc.ArgumentError,
                              "Column-based expression object expected "
                              "for argument 'mapping_spec'; got: 'a'",
                              collections.column_mapped_collection,
                              text('a'))

    def test_column_mapped_collection(self):
        children = self.tables.children

        collection_class = collections.column_mapped_collection(
            children.c.a)
        self._test_scalar_mapped(collection_class)

    def test_column_mapped_collection2(self):
        children = self.tables.children

        collection_class = collections.column_mapped_collection(
            (children.c.a, children.c.b))
        self._test_composite_mapped(collection_class)

    def test_mixin(self):
        class Ordered(util.OrderedDict, collections.MappedCollection):
            def __init__(self):
                collections.MappedCollection.__init__(self, lambda v: v.a)
                util.OrderedDict.__init__(self)
        collection_class = Ordered
        self._test_scalar_mapped(collection_class)

    def test_mixin2(self):
        class Ordered2(util.OrderedDict, collections.MappedCollection):
            def __init__(self, keyfunc):
                collections.MappedCollection.__init__(self, keyfunc)
                util.OrderedDict.__init__(self)
        collection_class = lambda: Ordered2(lambda v: (v.a, v.b))
        self._test_composite_mapped(collection_class)
1689
class ColumnMappedWSerialize(fixtures.MappedTest):
    """Check column_mapped_collection key functions against multi-table
    (inherited) and selectable-based mappings, and verify the resulting
    collections survive repeated pickle round trips."""

    run_create_tables = None
    run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        Table('foo', metadata,
              Column('id', Integer(), primary_key=True),
              Column('b', String(128)))
        Table('bar', metadata,
              Column('id', Integer(), primary_key=True),
              Column('foo_id', Integer, ForeignKey('foo.id')),
              Column('bat_id', Integer),
              schema="x")

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

        class Bar(Foo):
            pass

    def test_indirect_table_column_mapped(self):
        Foo, Bar = self.classes.Foo, self.classes.Bar
        bar_table = self.tables["x.bar"]
        foo_table = self.tables.foo
        mapper(Foo, foo_table, properties={"foo_id": foo_table.c.id})
        mapper(Bar, bar_table, inherits=Foo, properties={
            "bar_id": bar_table.c.id,
        })

        sample = Bar(foo_id=1, bar_id=2, bat_id=3)
        self._run_test([
            (Foo.foo_id, sample, 1),
            ((Bar.bar_id, Bar.bat_id), sample, (2, 3)),
            (Bar.foo_id, sample, 1),
            (bar_table.c.id, sample, 2),
        ])

    def test_selectable_column_mapped(self):
        from sqlalchemy import select
        sel = select([self.tables.foo]).alias()
        Foo = self.classes.Foo
        mapper(Foo, sel)
        self._run_test([
            (Foo.b, Foo(b=5), 5),
            (sel.c.b, Foo(b=5), 5),
        ])

    def _run_test(self, specs):
        """For each ``(spec, obj, expected)``, check the keyfunc of a fresh
        column_mapped_collection(spec) and of its pickled copies."""
        from sqlalchemy.testing.util import picklers
        for spec, obj, expected in specs:
            coll = collections.column_mapped_collection(spec)()
            eq_(coll.keyfunc(obj), expected)
            # pickle twice in a row so that the output of __reduce__ must
            # itself be re-picklable
            for loads, dumps in picklers():
                current = coll
                for _ in range(2):
                    current = loads(dumps(current))
                    eq_(current.keyfunc(obj), expected)
1759
1760class CustomCollectionsTest(fixtures.MappedTest):
1761    """test the integration of collections with mapped classes."""
1762
1763    @classmethod
1764    def define_tables(cls, metadata):
1765        Table('sometable', metadata,
1766              Column('col1', Integer, primary_key=True,
1767                                            test_needs_autoincrement=True),
1768              Column('data', String(30)))
1769        Table('someothertable', metadata,
1770              Column('col1', Integer, primary_key=True,
1771                                            test_needs_autoincrement=True),
1772              Column('scol1', Integer,
1773                     ForeignKey('sometable.col1')),
1774              Column('data', String(20)))
1775
1776    def test_basic(self):
1777        someothertable, sometable = self.tables.someothertable, \
1778                                            self.tables.sometable
1779
1780        class MyList(list):
1781            pass
1782        class Foo(object):
1783            pass
1784        class Bar(object):
1785            pass
1786
1787        mapper(Foo, sometable, properties={
1788            'bars': relationship(Bar, collection_class=MyList)
1789        })
1790        mapper(Bar, someothertable)
1791        f = Foo()
1792        assert isinstance(f.bars, MyList)
1793
1794    def test_lazyload(self):
1795        """test that a 'set' can be used as a collection and can lazyload."""
1796
1797        someothertable, sometable = self.tables.someothertable, \
1798                                            self.tables.sometable
1799
1800        class Foo(object):
1801            pass
1802        class Bar(object):
1803            pass
1804        mapper(Foo, sometable, properties={
1805            'bars': relationship(Bar, collection_class=set)
1806        })
1807        mapper(Bar, someothertable)
1808        f = Foo()
1809        f.bars.add(Bar())
1810        f.bars.add(Bar())
1811        sess = create_session()
1812        sess.add(f)
1813        sess.flush()
1814        sess.expunge_all()
1815        f = sess.query(Foo).get(f.col1)
1816        assert len(list(f.bars)) == 2
1817        f.bars.clear()
1818
1819    def test_dict(self):
1820        """test that a 'dict' can be used as a collection and can lazyload."""
1821
1822        someothertable, sometable = self.tables.someothertable, \
1823                                            self.tables.sometable
1824
1825
1826        class Foo(object):
1827            pass
1828        class Bar(object):
1829            pass
1830        class AppenderDict(dict):
1831            @collection.appender
1832            def set(self, item):
1833                self[id(item)] = item
1834            @collection.remover
1835            def remove(self, item):
1836                if id(item) in self:
1837                    del self[id(item)]
1838
1839        mapper(Foo, sometable, properties={
1840            'bars': relationship(Bar, collection_class=AppenderDict)
1841        })
1842        mapper(Bar, someothertable)
1843        f = Foo()
1844        f.bars.set(Bar())
1845        f.bars.set(Bar())
1846        sess = create_session()
1847        sess.add(f)
1848        sess.flush()
1849        sess.expunge_all()
1850        f = sess.query(Foo).get(f.col1)
1851        assert len(list(f.bars)) == 2
1852        f.bars.clear()
1853
1854    def test_dict_wrapper(self):
1855        """test that the supplied 'dict' wrapper can be used as a
1856        collection and can lazyload."""
1857
1858        someothertable, sometable = self.tables.someothertable, \
1859                                            self.tables.sometable
1860
1861
1862        class Foo(object):
1863            pass
1864        class Bar(object):
1865            def __init__(self, data): self.data = data
1866
1867        mapper(Foo, sometable, properties={
1868            'bars':relationship(Bar,
1869                collection_class=collections.column_mapped_collection(
1870                    someothertable.c.data))
1871        })
1872        mapper(Bar, someothertable)
1873
1874        f = Foo()
1875        col = collections.collection_adapter(f.bars)
1876        col.append_with_event(Bar('a'))
1877        col.append_with_event(Bar('b'))
1878        sess = create_session()
1879        sess.add(f)
1880        sess.flush()
1881        sess.expunge_all()
1882        f = sess.query(Foo).get(f.col1)
1883        assert len(list(f.bars)) == 2
1884
1885        existing = set([id(b) for b in list(f.bars.values())])
1886
1887        col = collections.collection_adapter(f.bars)
1888        col.append_with_event(Bar('b'))
1889        f.bars['a'] = Bar('a')
1890        sess.flush()
1891        sess.expunge_all()
1892        f = sess.query(Foo).get(f.col1)
1893        assert len(list(f.bars)) == 2
1894
1895        replaced = set([id(b) for b in list(f.bars.values())])
1896        self.assert_(existing != replaced)
1897
1898    def test_list(self):
1899        self._test_list(list)
1900
1901    def test_list_no_setslice(self):
1902        class ListLike(object):
1903            def __init__(self):
1904                self.data = list()
1905            def append(self, item):
1906                self.data.append(item)
1907            def remove(self, item):
1908                self.data.remove(item)
1909            def insert(self, index, item):
1910                self.data.insert(index, item)
1911            def pop(self, index=-1):
1912                return self.data.pop(index)
1913            def extend(self):
1914                assert False
1915            def __len__(self):
1916                return len(self.data)
1917            def __setitem__(self, key, value):
1918                self.data[key] = value
1919            def __getitem__(self, key):
1920                return self.data[key]
1921            def __delitem__(self, key):
1922                del self.data[key]
1923            def __iter__(self):
1924                return iter(self.data)
1925            __hash__ = object.__hash__
1926            def __eq__(self, other):
1927                return self.data == other
1928            def __repr__(self):
1929                return 'ListLike(%s)' % repr(self.data)
1930
1931        self._test_list(ListLike)
1932
1933    def _test_list(self, listcls):
1934        someothertable, sometable = self.tables.someothertable, \
1935                                        self.tables.sometable
1936
1937        class Parent(object):
1938            pass
1939        class Child(object):
1940            pass
1941
1942        mapper(Parent, sometable, properties={
1943            'children': relationship(Child, collection_class=listcls)
1944        })
1945        mapper(Child, someothertable)
1946
1947        control = list()
1948        p = Parent()
1949
1950        o = Child()
1951        control.append(o)
1952        p.children.append(o)
1953        assert control == p.children
1954        assert control == list(p.children)
1955
1956        o = [Child(), Child(), Child(), Child()]
1957        control.extend(o)
1958        p.children.extend(o)
1959        assert control == p.children
1960        assert control == list(p.children)
1961
1962        assert control[0] == p.children[0]
1963        assert control[-1] == p.children[-1]
1964        assert control[1:3] == p.children[1:3]
1965
1966        del control[1]
1967        del p.children[1]
1968        assert control == p.children
1969        assert control == list(p.children)
1970
1971        o = [Child()]
1972        control[1:3] = o
1973
1974        p.children[1:3] = o
1975        assert control == p.children
1976        assert control == list(p.children)
1977
1978        o = [Child(), Child(), Child(), Child()]
1979        control[1:3] = o
1980        p.children[1:3] = o
1981        assert control == p.children
1982        assert control == list(p.children)
1983
1984        o = [Child(), Child(), Child(), Child()]
1985        control[-1:-2] = o
1986        p.children[-1:-2] = o
1987        assert control == p.children
1988        assert control == list(p.children)
1989
1990        o = [Child(), Child(), Child(), Child()]
1991        control[4:] = o
1992        p.children[4:] = o
1993        assert control == p.children
1994        assert control == list(p.children)
1995
1996        o = Child()
1997        control.insert(0, o)
1998        p.children.insert(0, o)
1999        assert control == p.children
2000        assert control == list(p.children)
2001
2002        o = Child()
2003        control.insert(3, o)
2004        p.children.insert(3, o)
2005        assert control == p.children
2006        assert control == list(p.children)
2007
2008        o = Child()
2009        control.insert(999, o)
2010        p.children.insert(999, o)
2011        assert control == p.children
2012        assert control == list(p.children)
2013
2014        del control[0:1]
2015        del p.children[0:1]
2016        assert control == p.children
2017        assert control == list(p.children)
2018
2019        del control[1:1]
2020        del p.children[1:1]
2021        assert control == p.children
2022        assert control == list(p.children)
2023
2024        del control[1:3]
2025        del p.children[1:3]
2026        assert control == p.children
2027        assert control == list(p.children)
2028
2029        del control[7:]
2030        del p.children[7:]
2031        assert control == p.children
2032        assert control == list(p.children)
2033
2034        assert control.pop() == p.children.pop()
2035        assert control == p.children
2036        assert control == list(p.children)
2037
2038        assert control.pop(0) == p.children.pop(0)
2039        assert control == p.children
2040        assert control == list(p.children)
2041
2042        assert control.pop(2) == p.children.pop(2)
2043        assert control == p.children
2044        assert control == list(p.children)
2045
2046        o = Child()
2047        control.insert(2, o)
2048        p.children.insert(2, o)
2049        assert control == p.children
2050        assert control == list(p.children)
2051
2052        control.remove(o)
2053        p.children.remove(o)
2054        assert control == p.children
2055        assert control == list(p.children)
2056
2057    def test_custom(self):
2058        someothertable, sometable = self.tables.someothertable, \
2059                                        self.tables.sometable
2060
2061        class Parent(object):
2062            pass
2063        class Child(object):
2064            pass
2065
2066        class MyCollection(object):
2067            def __init__(self):
2068                self.data = []
2069            @collection.appender
2070            def append(self, value):
2071                self.data.append(value)
2072            @collection.remover
2073            def remove(self, value):
2074                self.data.remove(value)
2075            @collection.iterator
2076            def __iter__(self):
2077                return iter(self.data)
2078
2079        mapper(Parent, sometable, properties={
2080            'children': relationship(Child, collection_class=MyCollection)
2081        })
2082        mapper(Child, someothertable)
2083
2084        control = list()
2085        p1 = Parent()
2086
2087        o = Child()
2088        control.append(o)
2089        p1.children.append(o)
2090        assert control == list(p1.children)
2091
2092        o = Child()
2093        control.append(o)
2094        p1.children.append(o)
2095        assert control == list(p1.children)
2096
2097        o = Child()
2098        control.append(o)
2099        p1.children.append(o)
2100        assert control == list(p1.children)
2101
2102        sess = create_session()
2103        sess.add(p1)
2104        sess.flush()
2105        sess.expunge_all()
2106
2107        p2 = sess.query(Parent).get(p1.col1)
2108        o = list(p2.children)
2109        assert len(o) == 3
2110
2111
class InstrumentationTest(fixtures.ORMTest):
    """Low-level tests of collection class instrumentation."""

    def test_uncooperative_descriptor_in_sweep(self):
        # A class attribute whose __get__ raises AttributeError is listed
        # by dir() but hidden from hasattr(); instrumenting the class must
        # tolerate it (the test passes if no exception propagates).
        class DoNotTouch(object):
            def __get__(self, obj, owner):
                raise AttributeError

        class Touchy(list):
            no_touch = DoNotTouch()

        assert 'no_touch' in Touchy.__dict__
        assert not hasattr(Touchy, 'no_touch')
        assert 'no_touch' in dir(Touchy)

        collections._instrument_class(Touchy)

    def test_name_setup(self):
        # Methods marked with @collection role decorators are exposed as
        # _sa_<role> class attributes; a subclass keeps inherited roles and
        # picks up the ones it redeclares.
        class Base(object):
            @collection.iterator
            def base_iterate(self, x):
                return "base_iterate"

            @collection.appender
            def base_append(self, x):
                return "base_append"

            @collection.converter
            def base_convert(self, x):
                return "base_convert"

            @collection.remover
            def base_remove(self, x):
                return "base_remove"

        collections._instrument_class(Base)

        for attr, expected in [
            ('_sa_remover', "base_remove"),
            ('_sa_appender', "base_append"),
            ('_sa_iterator', "base_iterate"),
            ('_sa_converter', "base_convert"),
        ]:
            eq_(getattr(Base, attr)(Base(), 5), expected)

        class Sub(Base):
            @collection.converter
            def base_convert(self, x):
                return "sub_convert"

            @collection.remover
            def sub_remove(self, x):
                return "sub_remove"

        collections._instrument_class(Sub)

        for attr, expected in [
            ('_sa_appender', "base_append"),
            ('_sa_remover', "sub_remove"),
            ('_sa_iterator', "base_iterate"),
            ('_sa_converter', "sub_convert"),
        ]:
            eq_(getattr(Sub, attr)(Sub(), 5), expected)

    def test_link_event(self):
        # The @collection.linker hook is called with the collection's
        # adapter when it is attached, and with None once it is replaced.
        linked = []

        class Collection(list):
            @collection.linker
            def _on_link(self, obj):
                linked.append(obj)

        class Foo(object):
            pass

        instrumentation.register_class(Foo)
        attributes.register_attribute(Foo, 'attr', uselist=True,
                                      typecallable=Collection,
                                      useobject=True)

        owner = Foo()
        owner.attr.append(3)

        first_adapter = owner.attr._sa_adapter
        eq_(linked, [first_adapter])

        replacement = Collection()
        owner.attr = replacement
        eq_(linked, [first_adapter, owner.attr._sa_adapter, None])

    def test_referenced_by_owner(self):
        # The adapter reports _referenced_by_owner while its collection is
        # the owner's attribute value, and stops once it is reassigned.
        class Foo(object):
            pass

        instrumentation.register_class(Foo)
        attributes.register_attribute(
            Foo, 'attr', uselist=True, useobject=True)

        owner = Foo()
        owner.attr.append(3)

        adapter = collections.collection_adapter(owner.attr)
        assert adapter._referenced_by_owner

        owner.attr = []
        assert not adapter._referenced_by_owner
2211
2212
2213
2214
2215
2216