1from __future__ import absolute_import, print_function, division
2
3import unittest
4
5from pony.orm.core import *
6from pony.orm.tests.testutils import raises_exception
7from pony.orm.tests import setup_database, teardown_database, only_for
8
# Module-level Database object: the Person entity below is declared on it, and
# the test case hands it to setup_database()/teardown_database().
db = Database()
10
11
class Person(db.Entity):
    """Entity with a self-referencing, symmetric one-to-one ``spouse`` link."""

    # Mandatory textual name.
    name = Required(unicode)
    # Optional link to another Person; the attribute is its own reverse side.
    spouse = Optional("Person", reverse="spouse")
15
16
class TestSymmetricOne2One(unittest.TestCase):
    """Tests for the symmetric one-to-one ``Person.spouse`` relation.

    ``setUp`` rebuilds the fixture rows and then opens a db_session that stays
    active for the duration of the test body; ``tearDown`` closes it again.
    """

    @classmethod
    def setUpClass(cls):
        setup_database(db)

    @classmethod
    def tearDownClass(cls):
        teardown_database(db)

    def setUp(self):
        # Fixture: persons 1<->2 and 3<->4 are mutually linked, person 5 has
        # no spouse.
        with db_session:
            # Links are nulled before deleting -- presumably so the rows can be
            # removed despite referencing each other (confirm against schema).
            db.execute('update person set spouse=null')
            db.execute('delete from person')

            db.insert(Person, id=1, name='A')
            db.insert(Person, id=2, name='B', spouse=1)
            db.execute('update person set spouse=2 where id=1')

            db.insert(Person, id=3, name='C')
            db.insert(Person, id=4, name='D', spouse=3)
            db.execute('update person set spouse=4 where id=3')

            db.insert(Person, id=5, name='E', spouse=None)

        # Session for the test body itself; closed in tearDown().
        db_session.__enter__()

    def tearDown(self):
        db_session.__exit__()

    @staticmethod
    def _cached_spouse(person):
        """Return the ``spouse`` entry of ``person._vals_`` (None if missing)."""
        return person._vals_.get(Person.spouse)

    @staticmethod
    def _spouse_column():
        """Return the ``spouse`` values of all person rows, ordered by id."""
        return db.select('spouse from person order by id')

    def test1(self):
        # Link person 1 to the previously single person 5; the former
        # spouse (person 2) must end up unlinked.
        first = Person[1]
        second = Person[2]
        fifth = Person[5]

        first.spouse = fifth
        commit()

        self.assertEqual(self._cached_spouse(first), fifth)
        self.assertEqual(self._cached_spouse(fifth), first)
        self.assertEqual(self._cached_spouse(second), None)
        self.assertEqual([5, None, 4, 3, 1], self._spouse_column())

    def test2(self):
        # Clearing the link on one side must clear it on the other side too.
        first = Person[1]
        second = Person[2]

        first.spouse = None
        commit()

        self.assertEqual(self._cached_spouse(first), None)
        self.assertEqual(self._cached_spouse(second), None)
        self.assertEqual([None, None, 4, 3, None], self._spouse_column())

    def test3(self):
        # Link person 1 to person 3, who already has a spouse: both former
        # partners (2 and 4) must end up unlinked.
        first = Person[1]
        second = Person[2]
        third = Person[3]
        fourth = Person[4]

        first.spouse = third
        commit()

        self.assertEqual(self._cached_spouse(first), third)
        self.assertEqual(self._cached_spouse(second), None)
        self.assertEqual(self._cached_spouse(third), first)
        self.assertEqual(self._cached_spouse(fourth), None)
        self.assertEqual([3, None, 1, None, None], self._spouse_column())

    def test4(self):
        # Query through the relation: persons whose spouse is named B or D.
        query = select(p for p in Person if p.spouse.name in ('B', 'D'))
        persons = set(query)
        self.assertEqual(persons, {Person[1], Person[3]})

    @raises_exception(
        UnrepeatableReadError,
        'Multiple Person objects linked with the same Person[2] object. '
        'Maybe Person.spouse attribute should be Set instead of Optional')
    def test5(self):
        # Raw SQL makes row 2 point at person 3 while row 1 still points at
        # person 2, so the stored links are no longer symmetric.
        db.execute('update person set spouse = 3 where id = 2')
        first = Person[1]
        # The bare attribute reads below are the point of the test; the
        # expected error is presumably raised while they load data -- confirm.
        first.spouse
        second = Person[2]
        second.name

    def test6(self):
        # Same raw update as test5, but both objects are fetched first and
        # the attributes are read in the opposite order; no error is expected
        # and person 2's cached spouse must be person 1.
        db.execute('update person set spouse = 3 where id = 2')
        first = Person[1]
        second = Person[2]
        second.name
        first.spouse
        self.assertEqual(self._cached_spouse(second), first)
91
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
94