1from sqlalchemy.testing import eq_, assert_raises
2import copy
3import pickle
4
5from sqlalchemy import *
6from sqlalchemy.orm import *
7from sqlalchemy.orm.collections import collection, attribute_mapped_collection
8from sqlalchemy.ext.associationproxy import *
9from sqlalchemy.ext.associationproxy import _AssociationList
10from sqlalchemy.testing import assert_raises_message
11from sqlalchemy.testing.util import gc_collect
12from sqlalchemy.testing import fixtures, AssertsCompiledSQL
13from sqlalchemy import testing
14from sqlalchemy.testing.schema import Table, Column
15from sqlalchemy.testing.mock import Mock, call
16from sqlalchemy.testing.assertions import expect_warnings
17
class DictCollection(dict):
    """Dict-based collection that keys each member on its ``foo`` value."""

    @collection.appender
    def append(self, obj):
        # The member's ``foo`` attribute doubles as its mapping key.
        key = obj.foo
        self[key] = obj

    @collection.remover
    def remove(self, obj):
        # Look the member up by the same key it was stored under.
        key = obj.foo
        del self[key]
25
26
class SetCollection(set):
    """Trivial ``set`` subclass used as a user-defined collection class."""
29
30
class ListCollection(list):
    """Trivial ``list`` subclass used as a user-defined collection class."""
33
34
class ObjectCollection(object):
    """Non-builtin collection that stores its members in a plain list.

    It exposes only an appender, a remover and iteration; there is no
    ``__getitem__``, so it cannot be indexed.
    """

    def __init__(self):
        self.values = []

    @collection.appender
    def append(self, obj):
        values = self.values
        values.append(obj)

    @collection.remover
    def remove(self, obj):
        values = self.values
        values.remove(obj)

    def __iter__(self):
        values = self.values
        return iter(values)
46
47
class _CollectionOperations(fixtures.TestBase):
    """Shared schema, mappings and sequence checks for proxied collections.

    Subclasses set ``collection_class`` to the collection type used by the
    ``Parent._children`` relationship (``None`` selects the relationship's
    default) and call the ``_test_*`` helpers from their own tests.
    """

    def setup(self):
        """Create the Parent/Children tables and map fresh test classes.

        ``Parent.children`` is an association proxy over the ``_children``
        relationship that exposes each ``Child.name``.  When
        ``collection_class`` is a ``dict`` subclass, ``Child`` also takes a
        ``foo`` argument so that keyed collections have a value to key on.
        """
        collection_class = self.collection_class

        metadata = MetaData(testing.db)

        parents_table = Table('Parent', metadata,
                              Column('id', Integer, primary_key=True,
                                     test_needs_autoincrement=True),
                              Column('name', String(128)))
        children_table = Table('Children', metadata,
                               Column('id', Integer, primary_key=True,
                                      test_needs_autoincrement=True),
                               Column('parent_id', Integer,
                                      ForeignKey('Parent.id')),
                               Column('foo', String(128)),
                               Column('name', String(128)))

        class Parent(object):
            children = association_proxy('_children', 'name')

            def __init__(self, name):
                self.name = name

        class Child(object):
            # Dict-based collections need an extra ``foo`` value per child.
            if collection_class and issubclass(collection_class, dict):
                def __init__(self, foo, name):
                    self.foo = foo
                    self.name = name
            else:
                def __init__(self, name):
                    self.name = name

        mapper(Parent, parents_table, properties={
            '_children': relationship(Child, lazy='joined',
                                  collection_class=collection_class)})
        mapper(Child, children_table)

        metadata.create_all()

        self.metadata = metadata
        self.session = create_session()
        self.Parent, self.Child = Parent, Child

    def teardown(self):
        """Drop the tables created by ``setup``."""
        self.metadata.drop_all()

    def roundtrip(self, obj):
        """Flush ``obj``, clear the session and reload it by primary key.

        :param obj: a mapped instance, added to the session if needed.
        :returns: a freshly loaded instance of the same class, so later
            assertions reflect what was actually persisted.
        """
        if obj not in self.session:
            self.session.add(obj)
        self.session.flush()
        obj_id, type_ = obj.id, type(obj)
        self.session.expunge_all()
        return self.session.query(type_).get(obj_id)

    def _test_sequence_ops(self):
        """Exercise list-style operations through ``Parent.children``.

        Every mutation made through the proxy is checked against the
        underlying ``_children`` collection, and many are re-checked after
        a database round trip.
        """
        Parent, Child = self.Parent, self.Child

        p1 = Parent('P1')

        self.assert_(not p1._children)
        self.assert_(not p1.children)

        ch = Child('regular')
        p1._children.append(ch)

        self.assert_(ch in p1._children)
        self.assert_(len(p1._children) == 1)

        # The proxy exposes names, not Child objects.
        self.assert_(p1.children)
        self.assert_(len(p1.children) == 1)
        self.assert_(ch not in p1.children)
        self.assert_('regular' in p1.children)

        p1.children.append('proxied')

        self.assert_('proxied' in p1.children)
        self.assert_('proxied' not in p1._children)
        self.assert_(len(p1.children) == 2)
        self.assert_(len(p1._children) == 2)

        self.assert_(p1._children[0].name == 'regular')
        self.assert_(p1._children[1].name == 'proxied')

        del p1._children[1]

        self.assert_(len(p1._children) == 1)
        self.assert_(len(p1.children) == 1)
        self.assert_(p1._children[0] == ch)

        del p1.children[0]

        self.assert_(len(p1._children) == 0)
        self.assert_(len(p1.children) == 0)

        p1.children = ['a', 'b', 'c']
        self.assert_(len(p1._children) == 3)
        self.assert_(len(p1.children) == 3)

        # Release the local reference to the first child before reloading.
        del ch
        p1 = self.roundtrip(p1)

        self.assert_(len(p1._children) == 3)
        self.assert_(len(p1.children) == 3)

        popped = p1.children.pop()
        self.assert_(len(p1.children) == 2)
        self.assert_(popped not in p1.children)
        p1 = self.roundtrip(p1)
        self.assert_(len(p1.children) == 2)
        self.assert_(popped not in p1.children)

        # Item assignment must update the existing Child row in place.
        p1.children[1] = 'changed-in-place'
        self.assert_(p1.children[1] == 'changed-in-place')
        inplace_id = p1._children[1].id
        p1 = self.roundtrip(p1)
        self.assert_(p1.children[1] == 'changed-in-place')
        assert p1._children[1].id == inplace_id

        p1.children.append('changed-in-place')
        self.assert_(p1.children.count('changed-in-place') == 2)

        p1.children.remove('changed-in-place')
        self.assert_(p1.children.count('changed-in-place') == 1)

        p1 = self.roundtrip(p1)
        self.assert_(p1.children.count('changed-in-place') == 1)

        p1._children = []
        self.assert_(len(p1.children) == 0)

        # Slice assignment: same-length, shrinking, growing and extended.
        after = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
        p1.children = list(after)
        self.assert_(len(p1.children) == 10)
        self.assert_([c.name for c in p1._children] == after)

        p1.children[2:6] = ['x'] * 4
        after = ['a', 'b', 'x', 'x', 'x', 'x', 'g', 'h', 'i', 'j']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        p1.children[2:6] = ['y']
        after = ['a', 'b', 'y', 'g', 'h', 'i', 'j']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        p1.children[2:3] = ['z'] * 4
        after = ['a', 'b', 'z', 'z', 'z', 'z', 'g', 'h', 'i', 'j']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        p1.children[2::2] = ['O'] * 4
        after = ['a', 'b', 'O', 'z', 'O', 'z', 'O', 'h', 'O', 'j']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        # The proxy itself is unhashable.
        assert_raises(TypeError, set, [p1.children])

        # In-place repetition and concatenation.
        p1.children *= 0
        after = []
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        p1.children += ['a', 'b']
        after = ['a', 'b']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        p1.children[:] = ['d', 'e']
        after = ['d', 'e']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        p1.children[:] = ['a', 'b']

        p1.children += ['c']
        after = ['a', 'b', 'c']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        p1.children *= 1
        after = ['a', 'b', 'c']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        p1.children *= 2
        after = ['a', 'b', 'c', 'a', 'b', 'c']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        p1.children = ['a']
        after = ['a']
        self.assert_(p1.children == after)
        self.assert_([c.name for c in p1._children] == after)

        # Non-mutating arithmetic returns plain lists.
        self.assert_((p1.children * 2) == ['a', 'a'])
        self.assert_((2 * p1.children) == ['a', 'a'])
        self.assert_((p1.children * 0) == [])
        self.assert_((0 * p1.children) == [])

        self.assert_((p1.children + ['b']) == ['a', 'b'])
        self.assert_((['b'] + p1.children) == ['b', 'a'])

        # Concatenating a non-sequence must raise, as it does for a list.
        assert_raises(TypeError, lambda: p1.children + 123)
256
class DefaultTest(_CollectionOperations):
    """Sequence operations with the relationship's default collection."""

    collection_class = None

    def test_sequence_ops(self):
        # Delegate to the shared sequence scenario.
        run_ops = self._test_sequence_ops
        run_ops()
262
263
class ListTest(_CollectionOperations):
    """Sequence operations with an explicit ``list`` collection class."""

    collection_class = list

    def test_sequence_ops(self):
        # Delegate to the shared sequence scenario.
        run_ops = self._test_sequence_ops
        run_ops()
269
270
class CustomDictTest(_CollectionOperations):
    """Association proxy over a dict collection keyed on ``Child.foo``."""

    collection_class = DictCollection

    def test_mapping_ops(self):
        """Exercise mapping get/set/delete via the proxy and the raw dict."""
        Parent, Child = self.Parent, self.Child

        p1 = Parent('P1')

        self.assert_(not p1._children)
        self.assert_(not p1.children)

        ch = Child('a', 'regular')
        p1._children.append(ch)

        self.assert_(ch in list(p1._children.values()))
        self.assert_(len(p1._children) == 1)

        # The proxy maps keys to names, not to Child objects.
        self.assert_(p1.children)
        self.assert_(len(p1.children) == 1)
        self.assert_(ch not in p1.children)
        self.assert_('a' in p1.children)
        self.assert_(p1.children['a'] == 'regular')
        self.assert_(p1._children['a'] == ch)

        p1.children['b'] = 'proxied'

        self.assert_('proxied' in list(p1.children.values()))
        self.assert_('b' in p1.children)
        # The raw collection holds Child objects; the proxied string must
        # appear neither as a key nor as a value.
        self.assert_('proxied' not in p1._children)
        self.assert_('proxied' not in list(p1._children.values()))
        self.assert_(len(p1.children) == 2)
        self.assert_(len(p1._children) == 2)

        self.assert_(p1._children['a'].name == 'regular')
        self.assert_(p1._children['b'].name == 'proxied')

        del p1._children['b']

        self.assert_(len(p1._children) == 1)
        self.assert_(len(p1.children) == 1)
        self.assert_(p1._children['a'] == ch)

        del p1.children['a']

        self.assert_(len(p1._children) == 0)
        self.assert_(len(p1.children) == 0)

        p1.children = {'d': 'v d', 'e': 'v e', 'f': 'v f'}
        self.assert_(len(p1._children) == 3)
        self.assert_(len(p1.children) == 3)

        self.assert_(set(p1.children) == set(['d', 'e', 'f']))

        # Release the local reference to the first child before reloading.
        del ch
        p1 = self.roundtrip(p1)
        self.assert_(len(p1._children) == 3)
        self.assert_(len(p1.children) == 3)

        # Item assignment must update the existing Child row in place.
        p1.children['e'] = 'changed-in-place'
        self.assert_(p1.children['e'] == 'changed-in-place')
        inplace_id = p1._children['e'].id
        p1 = self.roundtrip(p1)
        self.assert_(p1.children['e'] == 'changed-in-place')
        self.assert_(p1._children['e'].id == inplace_id)

        p1._children = {}
        self.assert_(len(p1.children) == 0)

        # The dict-based relationship rejects non-mapping assignments.
        assert_raises(TypeError, setattr, p1, '_children', [])
        assert_raises(TypeError, setattr, p1, '_children', None)

        # The proxy itself is unhashable.
        assert_raises(TypeError, set, [p1.children])
351
352
class SetTest(_CollectionOperations):
    """Association proxy over a ``set``-based relationship collection."""

    collection_class = set

    def test_set_operations(self):
        """Exercise add/remove/discard/pop via the proxy and the raw set."""
        Parent, Child = self.Parent, self.Child

        p1 = Parent('P1')

        self.assert_(not p1._children)
        self.assert_(not p1.children)

        ch1 = Child('regular')
        p1._children.add(ch1)

        self.assert_(ch1 in p1._children)
        self.assert_(len(p1._children) == 1)

        # The proxy exposes names, not Child objects.
        self.assert_(p1.children)
        self.assert_(len(p1.children) == 1)
        self.assert_(ch1 not in p1.children)
        self.assert_('regular' in p1.children)

        p1.children.add('proxied')

        self.assert_('proxied' in p1.children)
        self.assert_('proxied' not in p1._children)
        self.assert_(len(p1.children) == 2)
        self.assert_(len(p1._children) == 2)

        self.assert_(set([o.name for o in p1._children]) ==
                     set(['regular', 'proxied']))

        # Locate the Child the proxy created for 'proxied'.
        ch2 = next((o for o in p1._children if o.name == 'proxied'), None)

        p1._children.remove(ch2)

        self.assert_(len(p1._children) == 1)
        self.assert_(len(p1.children) == 1)
        self.assert_(p1._children == set([ch1]))

        p1.children.remove('regular')

        self.assert_(len(p1._children) == 0)
        self.assert_(len(p1.children) == 0)

        p1.children = ['a', 'b', 'c']
        self.assert_(len(p1._children) == 3)
        self.assert_(len(p1.children) == 3)

        # Release the local reference to the first child before reloading.
        del ch1
        p1 = self.roundtrip(p1)

        self.assert_(len(p1._children) == 3)
        self.assert_(len(p1.children) == 3)

        self.assert_('a' in p1.children)
        self.assert_('b' in p1.children)
        self.assert_('d' not in p1.children)

        self.assert_(p1.children == set(['a', 'b', 'c']))

        # remove() of a missing member raises; discard() is a no-op.
        assert_raises(
            KeyError,
            p1.children.remove, "d"
        )

        self.assert_(len(p1.children) == 3)
        p1.children.discard('d')
        self.assert_(len(p1.children) == 3)
        p1 = self.roundtrip(p1)
        self.assert_(len(p1.children) == 3)

        popped = p1.children.pop()
        self.assert_(len(p1.children) == 2)
        self.assert_(popped not in p1.children)
        p1 = self.roundtrip(p1)
        self.assert_(len(p1.children) == 2)
        self.assert_(popped not in p1.children)

        p1.children = ['a', 'b', 'c']
        p1 = self.roundtrip(p1)
        self.assert_(p1.children == set(['a', 'b', 'c']))

        p1.children.discard('b')
        p1 = self.roundtrip(p1)
        self.assert_(p1.children == set(['a', 'c']))

        p1.children.remove('a')
        p1 = self.roundtrip(p1)
        self.assert_(p1.children == set(['c']))

        p1._children = set()
        self.assert_(len(p1.children) == 0)

        # The set-based relationship rejects list and None assignments.
        assert_raises(TypeError, setattr, p1, '_children', [])
        assert_raises(TypeError, setattr, p1, '_children', None)

        # The proxy itself is unhashable.
        assert_raises(TypeError, set, [p1.children])

    def test_set_comparisons(self):
        """Set algebra and comparison operators must agree with ``set``."""
        Parent = self.Parent

        p1 = Parent('P1')
        p1.children = ['a', 'b', 'c']
        control = set(['a', 'b', 'c'])

        for other in (set(['a', 'b', 'c']), set(['a', 'b', 'c', 'd']),
                      set(['a']), set(['a', 'b']),
                      set(['c', 'd']), set(['e', 'f', 'g']),
                      set()):

            eq_(p1.children.union(other),
                             control.union(other))
            eq_(p1.children.difference(other),
                             control.difference(other))
            eq_((p1.children - other),
                             (control - other))
            eq_(p1.children.intersection(other),
                             control.intersection(other))
            eq_(p1.children.symmetric_difference(other),
                             control.symmetric_difference(other))
            eq_(p1.children.issubset(other),
                             control.issubset(other))
            eq_(p1.children.issuperset(other),
                             control.issuperset(other))

            self.assert_((p1.children == other) == (control == other))
            self.assert_((p1.children != other) == (control != other))
            self.assert_((p1.children < other) == (control < other))
            self.assert_((p1.children <= other) == (control <= other))
            self.assert_((p1.children > other) == (control > other))
            self.assert_((p1.children >= other) == (control >= other))

    def _assert_children_equal(self, p, control, description):
        """Assert ``p.children == control``, printing context on failure.

        :param p: the Parent instance under test.
        :param control: a plain ``set`` holding the expected members.
        :param description: text identifying the operation being checked.
        """
        try:
            self.assert_(p.children == control)
        except Exception:
            print('Test %s:' % description)
            print('want', repr(control))
            print('got', repr(p.children))
            raise

    def test_set_mutation(self):
        """Named and augmented set mutations must match a plain ``set``.

        Each mutation is applied both to the proxy and to a control set.
        The results are compared, then compared again after a database
        round trip.
        """
        import operator

        Parent = self.Parent

        # Named mutator methods.
        for op in ('update', 'intersection_update',
                   'difference_update', 'symmetric_difference_update'):
            for base in (['a', 'b', 'c'], []):
                for other in (set(['a', 'b', 'c']), set(['a', 'b', 'c', 'd']),
                              set(['a']), set(['a', 'b']),
                              set(['c', 'd']), set(['e', 'f', 'g']),
                              set()):
                    p = Parent('p')
                    p.children = base[:]
                    control = set(base[:])

                    getattr(p.children, op)(other)
                    getattr(control, op)(other)
                    description = '%s.%s(%s)' % (set(base), op, other)
                    self._assert_children_equal(p, control, description)

                    p = self.roundtrip(p)
                    self._assert_children_equal(p, control, description)

        # Augmented-assignment operators.  ``target op= other`` is
        # ``target = operator.iXX(target, other)``; spelling it this way
        # avoids building source text for exec().
        inplace_ops = {
            '|=': operator.ior,
            '-=': operator.isub,
            '&=': operator.iand,
            '^=': operator.ixor,
        }
        for op in ('|=', '-=', '&=', '^='):
            apply_op = inplace_ops[op]
            for base in (['a', 'b', 'c'], []):
                for other in (set(['a', 'b', 'c']), set(['a', 'b', 'c', 'd']),
                              set(['a']), set(['a', 'b']),
                              set(['c', 'd']), set(['e', 'f', 'g']),
                              frozenset(['e', 'f', 'g']),
                              set()):
                    p = Parent('p')
                    p.children = base[:]
                    control = set(base[:])

                    # Assign back exactly as the augmented statement would.
                    p.children = apply_op(p.children, other)
                    control = apply_op(control, other)

                    description = '%s %s %s' % (set(base), op, other)
                    self._assert_children_equal(p, control, description)

                    p = self.roundtrip(p)
                    self._assert_children_equal(p, control, description)
567
568
class CustomSetTest(SetTest):
    """Re-run every SetTest case against a user-defined set subclass."""

    collection_class = SetCollection
571
class CustomObjectTest(_CollectionOperations):
    """Association proxy over a non-builtin, list-backed collection."""

    collection_class = ObjectCollection

    def test_basic(self):
        """Append and iterate through the proxy; indexing is unsupported."""
        make_parent = self.Parent

        parent = make_parent('p1')
        eq_(len(list(parent.children)), 0)

        parent.children.append('child')
        eq_(len(list(parent.children)), 1)

        parent = self.roundtrip(parent)
        eq_(len(list(parent.children)), 1)

        # No alternate _AssociationList implementation was supplied for
        # ObjectCollection, so indexing through the proxy must fail.
        assert_raises(TypeError, parent.children.__getitem__, 1)
593
class ProxyFactoryTest(ListTest):
    """ListTest variant in which the proxy is built by a custom factory.

    ``Parent.children`` uses ``proxy_factory`` / ``proxy_bulk_set`` to
    supply a user-defined ``_AssociationList`` subclass.
    """

    def setup(self):
        """Create the tables and map Parent/Child with the custom proxy."""
        metadata = MetaData(testing.db)

        parents_table = Table('Parent', metadata,
                              Column('id', Integer, primary_key=True,
                                     test_needs_autoincrement=True),
                              Column('name', String(128)))
        children_table = Table('Children', metadata,
                               Column('id', Integer, primary_key=True,
                                      test_needs_autoincrement=True),
                               Column('parent_id', Integer,
                                      ForeignKey('Parent.id')),
                               Column('foo', String(128)),
                               Column('name', String(128)))

        class CustomProxy(_AssociationList):
            # Proxy factory: receives the lazy collection, creator,
            # value attribute name and owning proxy.  ``value_attr`` is
            # accepted but unused; the getter/setter come from the
            # parent's ``_default_getset``.
            def __init__(
                self,
                lazy_collection,
                creator,
                value_attr,
                parent,
                ):
                # NOTE(review): _default_getset is handed the lazy
                # collection rather than a collection class; confirm this
                # is intended (it appears to yield the list-style setter).
                getter, setter = parent._default_getset(lazy_collection)
                _AssociationList.__init__(
                    self,
                    lazy_collection,
                    creator,
                    getter,
                    setter,
                    parent,
                    )

        class Parent(object):
            # Bulk assignment (``p.children = [...]``) goes through
            # CustomProxy.extend.
            children = association_proxy('_children', 'name',
                        proxy_factory=CustomProxy,
                        proxy_bulk_set=CustomProxy.extend
                    )

            def __init__(self, name):
                self.name = name

        class Child(object):
            def __init__(self, name):
                self.name = name

        mapper(Parent, parents_table, properties={
            '_children': relationship(Child, lazy='joined',
                                  collection_class=list)})
        mapper(Child, children_table)

        metadata.create_all()

        self.metadata = metadata
        self.session = create_session()
        self.Parent, self.Child = Parent, Child

    # Same body as the inherited ListTest.test_sequence_ops.
    def test_sequence_ops(self):
        self._test_sequence_ops()
654
655
class ScalarTest(fixtures.TestBase):
    """Association proxies over scalar (``uselist=False``) relationships."""

    @testing.provide_metadata
    def test_scalar_proxy(self):
        """Read, write and round-trip scalar proxies with and without creators.

        ``foo`` has no creator; ``bar`` and ``baz`` create a new Child when
        the relationship is empty.
        """
        metadata = self.metadata

        parents_table = Table('Parent', metadata,
                              Column('id', Integer, primary_key=True,
                                     test_needs_autoincrement=True),
                              Column('name', String(128)))
        children_table = Table('Children', metadata,
                               Column('id', Integer, primary_key=True,
                                      test_needs_autoincrement=True),
                               Column('parent_id', Integer,
                                      ForeignKey('Parent.id')),
                               Column('foo', String(128)),
                               Column('bar', String(128)),
                               Column('baz', String(128)))

        class Parent(object):
            foo = association_proxy('child', 'foo')
            bar = association_proxy('child', 'bar',
                                    creator=lambda v: Child(bar=v))
            baz = association_proxy('child', 'baz',
                                    creator=lambda v: Child(baz=v))

            def __init__(self, name):
                self.name = name

        class Child(object):
            # Accepts keyword arguments only.
            def __init__(self, **kw):
                for attr in kw:
                    setattr(self, attr, kw[attr])

        mapper(Parent, parents_table, properties={
            'child': relationship(Child, lazy='joined',
                              backref='parent', uselist=False)})
        mapper(Child, children_table)

        metadata.create_all()
        session = create_session()

        # Flush, clear the session and reload ``obj`` by primary key.
        def roundtrip(obj):
            if obj not in session:
                session.add(obj)
            session.flush()
            id, type_ = obj.id, type(obj)
            session.expunge_all()
            return session.query(type_).get(id)

        p = Parent('p')

        # With no child, the proxy reads as None.
        eq_(p.child, None)
        eq_(p.foo, None)

        p.child = Child(foo='a', bar='b', baz='c')

        self.assert_(p.foo == 'a')
        self.assert_(p.bar == 'b')
        self.assert_(p.baz == 'c')

        p.bar = 'x'
        self.assert_(p.foo == 'a')
        self.assert_(p.bar == 'x')
        self.assert_(p.baz == 'c')

        p = roundtrip(p)

        self.assert_(p.foo == 'a')
        self.assert_(p.bar == 'x')
        self.assert_(p.baz == 'c')

        p.child = None

        eq_(p.foo, None)

        # Bogus creator for this scalar type: ``foo`` has no creator, and
        # presumably the default one calls Child('zzz'), which
        # Child.__init__(**kw) rejects with TypeError.
        assert_raises(
            TypeError,
            setattr, p, "foo", "zzz"
        )

        # ``bar``'s creator builds a new Child holding only ``bar``.
        p.bar = 'yyy'

        self.assert_(p.foo is None)
        self.assert_(p.bar == 'yyy')
        self.assert_(p.baz is None)

        del p.child

        p = roundtrip(p)

        self.assert_(p.child is None)

        p.baz = 'xxx'

        self.assert_(p.foo is None)
        self.assert_(p.bar is None)
        self.assert_(p.baz == 'xxx')

        p = roundtrip(p)

        self.assert_(p.foo is None)
        self.assert_(p.bar is None)
        self.assert_(p.baz == 'xxx')

        # Ensure an immediate __set__ works.
        p2 = Parent('p2')
        p2.bar = 'quux'

    @testing.provide_metadata
    def test_empty_scalars(self):
        """Proxies through an empty scalar relationship read as None.

        Covers both a plain column (``name``) and a relationship (``b``)
        as the proxied attribute.
        """
        metadata = self.metadata

        a = Table('a', metadata,
                Column('id', Integer, primary_key=True),
                Column('name', String(50))
            )
        a2b = Table('a2b', metadata,
            Column('id', Integer, primary_key=True),
            Column('id_a', Integer, ForeignKey('a.id')),
            Column('id_b', Integer, ForeignKey('b.id')),
            Column('name', String(50))
        )
        b = Table('b', metadata,
            Column('id', Integer, primary_key=True),
            Column('name', String(50))
        )
        class A(object):
            a2b_name = association_proxy("a2b_single", "name")
            b_single = association_proxy("a2b_single", "b")

        class A2B(object):
            pass

        class B(object):
            pass

        mapper(A, a, properties=dict(
            a2b_single=relationship(A2B, uselist=False)
        ))

        mapper(A2B, a2b, properties=dict(
            b=relationship(B)
        ))
        mapper(B, b)

        a1 = A()
        assert a1.a2b_name is None
        assert a1.b_single is None

    def test_custom_getset(self):
        """A ``getset_factory`` supplies the getter/setter the proxy uses.

        No database access: the tables are only used for mapping.
        """
        metadata = MetaData()
        p = Table('p', metadata,
                              Column('id', Integer, primary_key=True),
                              Column('cid', Integer, ForeignKey('c.id')))
        c = Table('c', metadata,
                               Column('id', Integer, primary_key=True),
                               Column('foo', String(128)))

        get = Mock()
        set_ = Mock()
        class Parent(object):
            foo = association_proxy('child', 'foo',
                    getset_factory=lambda cc, parent: (get, set_))

        class Child(object):
            def __init__(self, foo):
                self.foo = foo

        mapper(Parent, p, properties={'child': relationship(Child)})
        mapper(Child, c)

        p1 = Parent()

        # A Mock returns the same return_value for any arguments, so these
        # checks confirm the proxy hands back whatever ``get`` returns.
        eq_(p1.foo, get(None))
        p1.child = child = Child(foo='x')
        eq_(p1.foo, get(child))
        # Writes must go through ``set_`` with the target and new value.
        p1.foo = "y"
        eq_(set_.mock_calls, [call(child, "y")])
835
836
837
class LazyLoadTest(fixtures.TestBase):
    """Check that association proxies respect the relationship loader strategy.

    ``setup`` maps only ``Child``; each test maps ``Parent`` itself with the
    loader strategy (``select`` or ``joined``) and ``uselist`` under test.
    """

    def setup(self):
        """Create the tables and define (but do not map) Parent."""
        metadata = MetaData(testing.db)

        parents_table = Table('Parent', metadata,
                              Column('id', Integer, primary_key=True,
                                     test_needs_autoincrement=True),
                              Column('name', String(128)))
        children_table = Table('Children', metadata,
                               Column('id', Integer, primary_key=True,
                                      test_needs_autoincrement=True),
                               Column('parent_id', Integer,
                                      ForeignKey('Parent.id')),
                               Column('foo', String(128)),
                               Column('name', String(128)))

        class Parent(object):
            children = association_proxy('_children', 'name')

            def __init__(self, name):
                self.name = name

        class Child(object):
            def __init__(self, name):
                self.name = name


        # Parent is mapped per test; see the test methods.
        mapper(Child, children_table)
        metadata.create_all()

        self.metadata = metadata
        self.session = create_session()
        self.Parent, self.Child = Parent, Child
        self.table = parents_table

    def teardown(self):
        """Drop the tables created by ``setup``."""
        self.metadata.drop_all()

    def roundtrip(self, obj):
        """Flush ``obj``, clear the session and reload it by primary key."""
        self.session.add(obj)
        self.session.flush()
        id, type_ = obj.id, type(obj)
        self.session.expunge_all()
        return self.session.query(type_).get(id)

    def test_lazy_list(self):
        """A lazy list relationship stays unloaded until accessed."""
        Parent, Child = self.Parent, self.Child

        mapper(Parent, self.table, properties={
            '_children': relationship(Child, lazy='select',
                                  collection_class=list)})

        p = Parent('p')
        p.children = ['a', 'b', 'c']

        p = self.roundtrip(p)

        # Is there a better way to ensure that the association_proxy
        # didn't convert a lazy load to an eager load?  This does work though.
        self.assert_('_children' not in p.__dict__)
        self.assert_(len(p._children) == 3)
        self.assert_('_children' in p.__dict__)

    def test_eager_list(self):
        """A joined-eager list relationship is populated on load."""
        Parent, Child = self.Parent, self.Child

        mapper(Parent, self.table, properties={
            '_children': relationship(Child, lazy='joined',
                                  collection_class=list)})

        p = Parent('p')
        p.children = ['a', 'b', 'c']

        p = self.roundtrip(p)

        self.assert_('_children' in p.__dict__)
        self.assert_(len(p._children) == 3)

    def test_slicing_list(self):
        """Indexing and negative slicing work over a lazy list."""
        Parent, Child = self.Parent, self.Child

        mapper(Parent, self.table, properties={
            '_children': relationship(Child, lazy='select',
                                  collection_class=list)})

        p = Parent('p')
        p.children = ['a', 'b', 'c']

        p = self.roundtrip(p)

        self.assert_(len(p._children) == 3)
        eq_('b', p.children[1])
        eq_(['b', 'c'], p.children[-2:])

    def test_lazy_scalar(self):
        """A lazy scalar relationship stays unloaded after reload."""
        Parent, Child = self.Parent, self.Child

        mapper(Parent, self.table, properties={
            '_children': relationship(Child, lazy='select', uselist=False)})


        p = Parent('p')
        p.children = 'value'

        p = self.roundtrip(p)

        self.assert_('_children' not in p.__dict__)
        self.assert_(p._children is not None)

    def test_eager_scalar(self):
        """A joined-eager scalar relationship is populated on load."""
        Parent, Child = self.Parent, self.Child

        mapper(Parent, self.table, properties={
            '_children': relationship(Child, lazy='joined', uselist=False)})


        p = Parent('p')
        p.children = 'value'

        p = self.roundtrip(p)

        self.assert_('_children' in p.__dict__)
        self.assert_(p._children is not None)
961
962
class Parent(object):
    """Module-level parent entity used by ReconstitutionTest.

    It lives at module scope so pickle can locate the class when
    reconstituting instances.  ReconstitutionTest.setup attaches a
    ``kids`` association proxy to it.
    """

    def __init__(self, name):
        # Plain attribute assignment, so ORM instrumentation sees the value
        # once the class is mapped.
        self.name = name
966
class Child(object):
    """Module-level child entity; its ``name`` is what ``Parent.kids`` proxies."""

    def __init__(self, name):
        # Plain attribute assignment, so ORM instrumentation sees the value
        # once the class is mapped.
        self.name = name
970
class KVChild(object):
    """Module-level child taking a (name, value) pair.

    Used as the target of a dict-based association proxy in
    ``ReconstitutionTest.test_pickle_dict``.
    """

    def __init__(self, name, value):
        # ``name`` is assigned first, then ``value``, as before.
        self.name, self.value = name, value
975
class ReconstitutionTest(fixtures.TestBase):
    """Association proxies must keep working across GC, copy and pickle.

    Uses the module-level ``Parent``/``Child``/``KVChild`` classes so that
    instances can be pickled.  Each test maps those classes itself, and
    ``teardown`` clears the mappers again.
    """

    def setup(self):
        metadata = MetaData(testing.db)
        parents = Table('parents', metadata,
                        Column('id', Integer, primary_key=True,
                               test_needs_autoincrement=True),
                        Column('name', String(30)))
        children = Table('children', metadata,
                         Column('id', Integer, primary_key=True,
                                test_needs_autoincrement=True),
                         Column('parent_id', Integer,
                                ForeignKey('parents.id')),
                         Column('name', String(30)))
        metadata.create_all()
        parents.insert().execute(name='p1')
        self.metadata = metadata
        self.parents = parents
        self.children = children
        # ``kids`` proxies the ``name`` attribute of each object in the
        # ``children`` relationship that the individual tests configure.
        Parent.kids = association_proxy('children', 'name')

    def teardown(self):
        self.metadata.drop_all()
        clear_mappers()

    def test_weak_identity_map(self):
        mapper(Parent, self.parents,
               properties=dict(children=relationship(Child)))
        mapper(Child, self.children)
        session = create_session(weak_identity_map=True)

        def add_child(parent_name, child_name):
            parent = \
                session.query(Parent).filter_by(name=parent_name).one()
            parent.kids.append(child_name)

        add_child('p1', 'c1')
        # Run a GC pass between the two calls, so the weak identity map may
        # drop the first Parent instance.  The proxy must keep working on
        # whatever instance the second query returns.
        gc_collect()
        add_child('p1', 'c2')
        session.flush()
        p = session.query(Parent).filter_by(name='p1').one()
        assert set(p.kids) == set(['c1', 'c2']), p.kids

    def test_copy(self):
        mapper(Parent, self.parents,
               properties=dict(children=relationship(Child)))
        mapper(Child, self.children)
        p = Parent('p1')
        p.kids.extend(['c1', 'c2'])
        p_copy = copy.copy(p)
        # Drop the original so the copy's proxy cannot lean on it.
        del p
        gc_collect()
        # ``p`` no longer exists here, so the failure message must report
        # the copy's value; referencing ``p`` would raise
        # UnboundLocalError instead of an informative AssertionError.
        assert set(p_copy.kids) == set(['c1', 'c2']), p_copy.kids

    def test_pickle_list(self):
        mapper(Parent, self.parents,
               properties=dict(children=relationship(Child)))
        mapper(Child, self.children)
        p = Parent('p1')
        p.kids.extend(['c1', 'c2'])
        r1 = pickle.loads(pickle.dumps(p))
        assert r1.kids == ['c1', 'c2']

        # Pickling ``p.kids`` on its own is not tested: per the original
        # note, that cannot work without the parent having a cycle.

    def test_pickle_set(self):
        mapper(Parent, self.parents,
               properties=dict(children=relationship(Child,
                                                     collection_class=set)))
        mapper(Child, self.children)
        p = Parent('p1')
        p.kids.update(['c1', 'c2'])
        r1 = pickle.loads(pickle.dumps(p))
        assert r1.kids == set(['c1', 'c2'])

        # Pickling ``p.kids`` on its own is not tested: per the original
        # note, that cannot work without the parent having a cycle.

    def test_pickle_dict(self):
        # PickleKeyFunc is a picklable replacement for a lambda key func.
        mapper(Parent, self.parents,
               properties=dict(children=relationship(KVChild,
               collection_class=
                    collections.mapped_collection(PickleKeyFunc('name')))))
        mapper(KVChild, self.children)
        p = Parent('p1')
        p.kids.update({'c1': 'v1', 'c2': 'v2'})
        # The proxy reads back each child's ``name``, not the value passed
        # to update(), hence the name -> name mapping below.
        assert p.kids == {'c1': 'c1', 'c2': 'c2'}
        r1 = pickle.loads(pickle.dumps(p))
        assert r1.kids == {'c1': 'c1', 'c2': 'c2'}

        # Pickling ``p.kids`` on its own is not tested: per the original
        # note, that cannot work without the parent having a cycle.
1072
class PickleKeyFunc(object):
    """Picklable key function returning ``getattr(obj, name)``.

    A module-level callable class is used instead of a lambda so that a
    mapped collection keyed by it can survive a pickle round trip.
    """

    def __init__(self, name):
        # Name of the attribute to read off each object.
        self.name = name

    def __call__(self, obj):
        attr_name = self.name
        return getattr(obj, attr_name)
1079
class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
    """SQL operators (any/has/contains/==/!=) on association proxies.

    Each proxy expression is checked against the equivalent hand-written
    relationship expression, either by comparing query results
    (``_equivalent``) or the compiled SQL.

    Test-name suffixes such as ``ul_nul`` mirror the comments on the
    proxies in ``setup_classes``: ``ul`` is a uselist (collection)
    relationship and ``nul`` a non-uselist (scalar) one.

    ``== None`` / ``!= None`` are intentional throughout: SQLAlchemy
    overloads these operators to render IS NULL / IS NOT NULL.
    """
    __dialect__ = 'default'

    run_inserts = 'once'
    run_deletes = None
    run_setup_mappers = 'once'
    run_setup_classes = 'once'

    @classmethod
    def define_tables(cls, metadata):
        Table('userkeywords', metadata,
              Column('keyword_id', Integer, ForeignKey('keywords.id'),
                     primary_key=True),
              Column('user_id', Integer, ForeignKey('users.id')),
              Column('value', String(50)))
        Table('users', metadata,
              Column('id', Integer,
                     primary_key=True, test_needs_autoincrement=True),
              Column('name', String(64)),
              Column('singular_id', Integer, ForeignKey('singular.id')))
        Table('keywords', metadata,
              Column('id', Integer,
                     primary_key=True, test_needs_autoincrement=True),
              Column('keyword', String(64)),
              Column('singular_id', Integer, ForeignKey('singular.id')))
        Table('singular', metadata,
              Column('id', Integer,
                     primary_key=True, test_needs_autoincrement=True),
              Column('value', String(50)))

    @classmethod
    def setup_classes(cls):
        class User(cls.Comparable):
            def __init__(self, name):
                self.name = name

            # o2m -> m2o
            # uselist -> nonuselist
            keywords = association_proxy('user_keywords', 'keyword',
                    creator=lambda k: UserKeyword(keyword=k))

            # m2o -> o2m
            # nonuselist -> uselist
            singular_keywords = association_proxy('singular', 'keywords')

            # m2o -> scalar
            # nonuselist
            singular_value = association_proxy('singular', 'value')

            # o2m -> scalar
            singular_collection = association_proxy('user_keywords', 'value')

        class Keyword(cls.Comparable):
            def __init__(self, keyword):
                self.keyword = keyword

            # o2o -> m2o
            # nonuselist -> nonuselist
            user = association_proxy('user_keyword', 'user')

        class UserKeyword(cls.Comparable):
            def __init__(self, user=None, keyword=None):
                self.user = user
                self.keyword = keyword

        class Singular(cls.Comparable):
            def __init__(self, value=None):
                self.value = value

    @classmethod
    def setup_mappers(cls):
        users, keywords, userkeywords, singular = (
            cls.tables.users, cls.tables.keywords,
            cls.tables.userkeywords, cls.tables.singular)
        User, Keyword, UserKeyword, Singular = (
            cls.classes.User, cls.classes.Keyword,
            cls.classes.UserKeyword, cls.classes.Singular)

        mapper(User, users, properties={
            'singular': relationship(Singular)
        })
        mapper(Keyword, keywords, properties={
            'user_keyword': relationship(UserKeyword, uselist=False)
        })

        mapper(UserKeyword, userkeywords, properties={
            'user': relationship(User, backref='user_keywords'),
            'keyword': relationship(Keyword)
        })
        mapper(Singular, singular, properties={
            'keywords': relationship(Keyword)
        })

    @classmethod
    def insert_data(cls):
        UserKeyword, User, Keyword, Singular = (cls.classes.UserKeyword,
                                                cls.classes.User,
                                                cls.classes.Keyword,
                                                cls.classes.Singular)

        session = sessionmaker()()
        words = (
            'quick', 'brown',
            'fox', 'jumped', 'over',
            'the', 'lazy',
        )
        for ii in range(16):
            user = User('user%d' % ii)

            # Even-numbered users get a Singular; only every fourth user's
            # Singular has a non-NULL value.
            if ii % 2 == 0:
                user.singular = Singular(value=("singular%d" % ii)
                                         if ii % 4 == 0 else None)
            session.add(user)
            # Each user gets either three consecutive words or none at all.
            # The slice is empty when (ii + 3) % len(words) wraps around.
            for jj in words[(ii % len(words)):((ii + 3) % len(words))]:
                k = Keyword(jj)
                user.keywords.append(k)
                if ii % 2 == 0:
                    user.singular.keywords.append(k)
                    user.user_keywords[-1].value = "singular%d" % ii

        # A Keyword whose association row has no User...
        orphan = Keyword('orphan')
        orphan.user_keyword = UserKeyword(keyword=orphan, user=None)
        session.add(orphan)

        # ...and a Keyword with no association row at all.
        keyword_with_nothing = Keyword('kwnothing')
        session.add(keyword_with_nothing)

        session.commit()
        # ``user`` is the last one created (user15: odd, so it has no
        # Singular).  ``kw`` is its first keyword.
        cls.u = user
        cls.kw = user.keywords[0]
        cls.session = session

    def _equivalent(self, q_proxy, q_direct):
        """Assert the proxy-based query returns the same rows as the
        hand-written relationship query."""
        eq_(q_proxy.all(), q_direct.all())

    def test_filter_any_criterion_ul_scalar(self):
        UserKeyword, User = self.classes.UserKeyword, self.classes.User

        q1 = self.session.query(User).filter(
            User.singular_collection.any(UserKeyword.value == 'singular8'))
        self.assert_compile(
            q1,
            "SELECT users.id AS users_id, users.name AS users_name, "
            "users.singular_id AS users_singular_id "
            "FROM users "
            "WHERE EXISTS (SELECT 1 "
            "FROM userkeywords "
            "WHERE users.id = userkeywords.user_id AND "
            "userkeywords.value = :value_1)",
            checkparams={'value_1': 'singular8'}
        )

        q2 = self.session.query(User).filter(
            User.user_keywords.any(UserKeyword.value == 'singular8'))
        self._equivalent(q1, q2)

    def test_filter_any_kwarg_ul_nul(self):
        UserKeyword, User = self.classes.UserKeyword, self.classes.User

        self._equivalent(
            self.session.query(User).filter(
                User.keywords.any(keyword='jumped')),
            self.session.query(User).filter(
                User.user_keywords.any(
                    UserKeyword.keyword.has(keyword='jumped'))))

    def test_filter_has_kwarg_nul_nul(self):
        UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword

        self._equivalent(
            self.session.query(Keyword).filter(
                Keyword.user.has(name='user2')),
            self.session.query(Keyword).filter(
                Keyword.user_keyword.has(
                    UserKeyword.user.has(name='user2'))))

    def test_filter_has_kwarg_nul_ul(self):
        User, Singular = self.classes.User, self.classes.Singular

        self._equivalent(
            self.session.query(User).filter(
                User.singular_keywords.any(keyword='jumped')),
            self.session.query(User).filter(
                User.singular.has(
                    Singular.keywords.any(keyword='jumped'))))

    def test_filter_any_criterion_ul_nul(self):
        UserKeyword, User, Keyword = (self.classes.UserKeyword,
                                      self.classes.User,
                                      self.classes.Keyword)

        self._equivalent(
            self.session.query(User).filter(
                User.keywords.any(Keyword.keyword == 'jumped')),
            self.session.query(User).filter(
                User.user_keywords.any(
                    UserKeyword.keyword.has(Keyword.keyword == 'jumped'))))

    def test_filter_has_criterion_nul_nul(self):
        UserKeyword, User, Keyword = (self.classes.UserKeyword,
                                      self.classes.User,
                                      self.classes.Keyword)

        self._equivalent(
            self.session.query(Keyword).filter(
                Keyword.user.has(User.name == 'user2')),
            self.session.query(Keyword).filter(
                Keyword.user_keyword.has(
                    UserKeyword.user.has(User.name == 'user2'))))

    def test_filter_any_criterion_nul_ul(self):
        User, Keyword, Singular = (self.classes.User,
                                   self.classes.Keyword,
                                   self.classes.Singular)

        self._equivalent(
            self.session.query(User).filter(
                User.singular_keywords.any(Keyword.keyword == 'jumped')),
            self.session.query(User).filter(
                User.singular.has(
                    Singular.keywords.any(Keyword.keyword == 'jumped'))))

    def test_filter_contains_ul_nul(self):
        User = self.classes.User

        self._equivalent(
            self.session.query(User).filter(
                User.keywords.contains(self.kw)),
            self.session.query(User).filter(
                User.user_keywords.any(keyword=self.kw)))

    def test_filter_contains_nul_ul(self):
        User, Singular = self.classes.User, self.classes.Singular

        # self.kw belongs to an odd-numbered user, so it was never added to
        # a Singular.  Its NULL singular_id is presumably what triggers the
        # expected warning.
        with expect_warnings(
                "Got None for value of column keywords.singular_id;"):
            self._equivalent(
                self.session.query(User).filter(
                    User.singular_keywords.contains(self.kw)),
                self.session.query(User).filter(
                    User.singular.has(
                        Singular.keywords.contains(self.kw))),
            )

    def test_filter_eq_nul_nul(self):
        Keyword = self.classes.Keyword

        self._equivalent(
            self.session.query(Keyword).filter(Keyword.user == self.u),
            self.session.query(Keyword).filter(
                Keyword.user_keyword.has(user=self.u)))

    def test_filter_ne_nul_nul(self):
        Keyword = self.classes.Keyword
        UserKeyword = self.classes.UserKeyword

        # The expected side must be spelled with the plain relationship.
        # Using the proxy (Keyword.user) here again would make the test
        # largely compare the proxy against itself.
        self._equivalent(
            self.session.query(Keyword).filter(Keyword.user != self.u),
            self.session.query(Keyword).filter(
                Keyword.user_keyword.has(UserKeyword.user != self.u)))

    def test_filter_eq_null_nul_nul(self):
        UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword

        # "user is NULL" matches both a NULL user on the association row
        # and a missing association row.
        self._equivalent(
            self.session.query(Keyword).filter(Keyword.user == None),
            self.session.query(Keyword).filter(
                or_(
                    Keyword.user_keyword.has(UserKeyword.user == None),
                    Keyword.user_keyword == None
                )))

    def test_filter_ne_null_nul_nul(self):
        UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword

        self._equivalent(
            self.session.query(Keyword).filter(Keyword.user != None),
            self.session.query(Keyword).filter(
                Keyword.user_keyword.has(UserKeyword.user != None)))

    def test_filter_eq_None_nul(self):
        User = self.classes.User
        Singular = self.classes.Singular

        self._equivalent(
            self.session.query(User).filter(User.singular_value == None),
            self.session.query(User).filter(
                or_(
                    User.singular.has(Singular.value == None),
                    User.singular == None
                )))

    def test_filter_ne_value_nul(self):
        User = self.classes.User
        Singular = self.classes.Singular

        self._equivalent(
            self.session.query(User).filter(
                User.singular_value != "singular4"),
            self.session.query(User).filter(
                User.singular.has(Singular.value != "singular4")))

    def test_filter_eq_value_nul(self):
        User = self.classes.User
        Singular = self.classes.Singular

        self._equivalent(
            self.session.query(User).filter(
                User.singular_value == "singular4"),
            self.session.query(User).filter(
                User.singular.has(Singular.value == "singular4")))

    def test_filter_ne_None_nul(self):
        User = self.classes.User
        Singular = self.classes.Singular

        self._equivalent(
            self.session.query(User).filter(User.singular_value != None),
            self.session.query(User).filter(
                User.singular.has(Singular.value != None)))

    def test_has_nul(self):
        # A special case: an empty has() on a non-object-targeted
        # association proxy.
        User = self.classes.User

        self._equivalent(
            self.session.query(User).filter(User.singular_value.has()),
            self.session.query(User).filter(User.singular.has()))

    def test_nothas_nul(self):
        # A special case: an empty has() on a non-object-targeted
        # association proxy, negated.
        User = self.classes.User

        self._equivalent(
            self.session.query(User).filter(~User.singular_value.has()),
            self.session.query(User).filter(~User.singular.has()))

    def test_has_criterion_nul(self):
        # ...but a has() with a criterion is not allowed.
        User = self.classes.User

        assert_raises_message(
            exc.ArgumentError,
            r"Non-empty has\(\) not allowed",
            User.singular_value.has,
            User.singular_value == "singular4"
        )

    def test_has_kwargs_nul(self):
        # ...nor one with kwargs.
        User = self.classes.User

        assert_raises_message(
            exc.ArgumentError,
            r"Non-empty has\(\) not allowed",
            User.singular_value.has, singular_value="singular4"
        )

    def test_filter_scalar_contains_fails_nul_nul(self):
        Keyword = self.classes.Keyword

        assert_raises(exc.InvalidRequestError,
                      lambda: Keyword.user.contains(self.u))

    def test_filter_scalar_any_fails_nul_nul(self):
        Keyword = self.classes.Keyword

        assert_raises(exc.InvalidRequestError,
                      lambda: Keyword.user.any(name='user2'))

    def test_filter_collection_has_fails_ul_nul(self):
        User = self.classes.User

        assert_raises(exc.InvalidRequestError,
                      lambda: User.keywords.has(keyword='quick'))

    def test_filter_collection_eq_fails_ul_nul(self):
        User = self.classes.User

        assert_raises(exc.InvalidRequestError,
                      lambda: User.keywords == self.kw)

    def test_filter_collection_ne_fails_ul_nul(self):
        User = self.classes.User

        assert_raises(exc.InvalidRequestError,
                      lambda: User.keywords != self.kw)

    def test_join_separate_attr(self):
        User = self.classes.User
        self.assert_compile(
            self.session.query(User).join(
                User.keywords.local_attr,
                User.keywords.remote_attr),
            "SELECT users.id AS users_id, users.name AS users_name, "
            "users.singular_id AS users_singular_id "
            "FROM users JOIN userkeywords ON users.id = "
            "userkeywords.user_id JOIN keywords ON keywords.id = "
            "userkeywords.keyword_id"
        )

    def test_join_single_attr(self):
        User = self.classes.User
        self.assert_compile(
            self.session.query(User).join(
                *User.keywords.attr),
            "SELECT users.id AS users_id, users.name AS users_name, "
            "users.singular_id AS users_singular_id "
            "FROM users JOIN userkeywords ON users.id = "
            "userkeywords.user_id JOIN keywords ON keywords.id = "
            "userkeywords.keyword_id"
        )
1537
class DictOfTupleUpdateTest(fixtures.TestBase):
    """``update()`` on a dict-based association proxy keyed by tuples.

    ``A.elements`` proxies the ``elem`` attribute of the ``B`` objects held
    in an ``attribute_mapped_collection('key')``.  ``B`` itself is the
    proxy's creator and takes ``(key, elem)``.
    """

    def setup(self):
        class B(object):
            def __init__(self, key, elem):
                self.key = key
                self.elem = elem

        class A(object):
            elements = association_proxy("orig", "elem", creator=B)

        meta = MetaData()
        a_table = Table('a', meta, Column('id', Integer, primary_key=True))
        b_table = Table('b', meta,
                        Column('id', Integer, primary_key=True),
                        Column('aid', Integer, ForeignKey('a.id')))
        orig_rel = relationship(
            B, collection_class=attribute_mapped_collection('key'))
        mapper(A, a_table, properties={'orig': orig_rel})
        mapper(B, b_table)
        self.A, self.B = A, B

    def test_update_one_elem_dict(self):
        expected = {("B", 3): 'elem2'}
        obj = self.A()
        obj.elements.update(expected)
        eq_(obj.elements, expected)

    def test_update_multi_elem_dict(self):
        expected = {("B", 3): 'elem2', ("C", 4): "elem3"}
        obj = self.A()
        obj.elements.update(expected)
        eq_(obj.elements, expected)

    def test_update_one_elem_list(self):
        pairs = [(("B", 3), 'elem2')]
        obj = self.A()
        obj.elements.update(pairs)
        eq_(obj.elements, dict(pairs))

    def test_update_multi_elem_list(self):
        pairs = [(("B", 3), 'elem2'), (("C", 4), "elem3")]
        obj = self.A()
        obj.elements.update(pairs)
        eq_(obj.elements, dict(pairs))

    def test_update_one_elem_varg(self):
        # A bare (key, value) tuple is not a sequence of 2-element pairs.
        obj = self.A()
        assert_raises_message(
            ValueError,
            "dictionary update sequence requires "
            "2-element tuples",
            obj.elements.update, (("B", 3), 'elem2')
        )

    def test_update_multi_elem_varg(self):
        # Like dict.update(), at most one positional argument is accepted.
        obj = self.A()
        assert_raises_message(
            TypeError,
            "update expected at most 1 arguments, got 2",
            obj.elements.update,
            (("B", 3), 'elem2'), (("C", 4), "elem3")
        )
1596
1597
class InfoTest(fixtures.TestBase):
    """The ``info`` dictionary carried by association proxies."""

    def test_constructor(self):
        proxy = association_proxy('a', 'b', info={'some_assoc': 'some_value'})
        eq_(proxy.info, {"some_assoc": "some_value"})

    def test_empty(self):
        eq_(association_proxy('a', 'b').info, {})

    def test_via_cls(self):
        class Foob(object):
            assoc = association_proxy('a', 'b')

        eq_(Foob.assoc.info, {})

        # Mutating info through one class-level access must be visible
        # on the next one.
        Foob.assoc.info["foo"] = 'bar'
        eq_(Foob.assoc.info, {'foo': 'bar'})
1616