1from __future__ import absolute_import, print_function, division
2
3import unittest
4from pony.orm.core import *
5from pony.orm.tests import setup_database, teardown_database
6
# Module-level Database shared by the Person entity and the tests in this file;
# the test class binds/creates it via setup_database and drops it via
# teardown_database.
db = Database()
8
9
class Person(db.Entity):
    # Self-referential entity used to exercise a symmetric many-to-many relation.
    name = Required(unicode)
    # Symmetric relation: the collection is its own reverse (reverse='friends'),
    # so adding or removing a friend on one side is reflected on the other.
    # Links live in the person_friends table (columns person, person_2), with
    # one row per direction.
    friends = Set('Person', reverse='friends')
13
14
class TestSymmetricM2M(unittest.TestCase):
    """Tests for the symmetric many-to-many relation ``Person.friends``.

    ``Person.friends`` is declared with ``reverse='friends'``, so the relation
    is its own reverse and every link is stored in the ``person_friends`` table
    once per direction.  Fixture: persons A..E (ids 1..5), where person 1 is
    friends with persons 2 and 3.  Each test body runs inside a db_session
    opened in setUp; uncommitted changes are rolled back in tearDown.
    """

    @classmethod
    def setUpClass(cls):
        setup_database(db)

    @classmethod
    def tearDownClass(cls):
        teardown_database(db)

    def setUp(self):
        """Reset the fixture data and open a db_session for the test body."""
        with db_session:
            for person in Person.select():
                person.delete()
        with db_session:
            for person_id, name in enumerate('ABCDE', start=1):
                db.insert(Person, id=person_id, name=name)
            # The relation is symmetric, so each friendship is stored twice:
            # once as (a, b) and once as (b, a).
            for left, right in ((1, 2), (2, 1), (1, 3), (3, 1)):
                db.insert(Person.friends, person=left, person_2=right)
        # Enter a session manually so that every test method runs inside it;
        # tearDown rolls back and exits it.
        db_session.__enter__()

    def tearDown(self):
        """Discard uncommitted changes and close the session opened in setUp."""
        rollback()
        db_session.__exit__()

    def _select_friend_rows(self):
        """Return every raw link row of Person.friends, ordered deterministically."""
        return db.select("* from person_friends order by person, person_2")

    def test1a(self):
        """Adding p4 to p1.friends makes p1 appear in p4.friends."""
        p1 = Person[1]
        p4 = Person[4]
        p1.friends.add(p4)
        self.assertEqual(set(p4.friends), {p1})

    def test1b(self):
        """Adding p4 to p1.friends keeps p1's existing friends."""
        p1 = Person[1]
        p4 = Person[4]
        p1.friends.add(p4)
        self.assertEqual(set(p1.friends), {Person[2], Person[3], p4})

    def test1c(self):
        """A committed add stores the new link in both directions."""
        p1 = Person[1]
        p4 = Person[4]
        p1.friends.add(p4)
        commit()
        self.assertEqual(self._select_friend_rows(),
                         [(1, 2), (1, 3), (1, 4), (2, 1), (3, 1), (4, 1)])

    def test2a(self):
        """Removing p2 from p1.friends leaves only p3 there."""
        p1 = Person[1]
        p2 = Person[2]
        p1.friends.remove(p2)
        self.assertEqual(set(p1.friends), {Person[3]})

    def test2b(self):
        """Removing p2 from p1.friends does not affect the p1-p3 link."""
        p1 = Person[1]
        p2 = Person[2]
        p1.friends.remove(p2)
        self.assertEqual(set(Person[3].friends), {p1})

    def test2c(self):
        """Removing p2 from p1.friends also removes p1 from p2.friends."""
        p1 = Person[1]
        p2 = Person[2]
        p1.friends.remove(p2)
        self.assertEqual(set(p2.friends), set())

    def test2d(self):
        """A committed remove deletes the link rows in both directions."""
        p1 = Person[1]
        p2 = Person[2]
        p1.friends.remove(p2)
        commit()
        self.assertEqual(self._select_friend_rows(), [(1, 3), (3, 1)])

    def test3a(self):
        """A one-sided link is detected when p1.friends is loaded after p2's.

        Only the row (1, 2) is stored, without its mirror (2, 1).  p2.friends
        is expected to load as empty; loading p1.friends afterwards is expected
        to raise UnrepeatableReadError reporting that Person[1] appeared in
        Person[2].friends.
        """
        db.execute('delete from person_friends')
        db.insert(Person.friends, person=1, person_2=2)
        p1 = Person[1]
        p2 = Person[2]
        self.assertEqual(set(p2.friends), set())
        with self.assertRaises(UnrepeatableReadError) as ctx:
            set(p1.friends)
        self.assertEqual(
            ctx.exception.args[0],
            "Phantom object Person[1] appeared in collection Person[2].friends")

    def test3b(self):
        """A one-sided link is detected when p2.friends is loaded after p1's.

        Only the row (1, 2) is stored, without its mirror (2, 1).  p1.friends
        is expected to load as {p2}; loading p2.friends afterwards is expected
        to raise UnrepeatableReadError reporting that Person[1] disappeared
        from Person[2].friends.
        """
        db.execute('delete from person_friends')
        db.insert(Person.friends, person=1, person_2=2)
        p1 = Person[1]
        p2 = Person[2]
        self.assertEqual(set(p1.friends), {p2})
        with self.assertRaises(UnrepeatableReadError) as ctx:
            set(p2.friends)
        self.assertEqual(
            ctx.exception.args[0],
            "Phantom object Person[1] disappeared from collection Person[2].friends")
102
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
105