1from __future__ import absolute_import, print_function, division 2 3import unittest 4from pony.orm.core import * 5from pony.orm.tests import setup_database, teardown_database 6 7db = Database() 8 9 10class Person(db.Entity): 11 name = Required(unicode) 12 friends = Set('Person', reverse='friends') 13 14 15class TestSymmetricM2M(unittest.TestCase): 16 @classmethod 17 def setUpClass(cls): 18 setup_database(db) 19 20 @classmethod 21 def tearDownClass(cls): 22 teardown_database(db) 23 24 def setUp(self): 25 with db_session: 26 for p in Person.select(): p.delete() 27 with db_session: 28 db.insert(Person, id=1, name='A') 29 db.insert(Person, id=2, name='B') 30 db.insert(Person, id=3, name='C') 31 db.insert(Person, id=4, name='D') 32 db.insert(Person, id=5, name='E') 33 db.insert(Person.friends, person=1, person_2=2) 34 db.insert(Person.friends, person=2, person_2=1) 35 db.insert(Person.friends, person=1, person_2=3) 36 db.insert(Person.friends, person=3, person_2=1) 37 db_session.__enter__() 38 def tearDown(self): 39 rollback() 40 db_session.__exit__() 41 def test1a(self): 42 p1 = Person[1] 43 p4 = Person[4] 44 p1.friends.add(p4) 45 self.assertEqual(set(p4.friends), {p1}) 46 def test1b(self): 47 p1 = Person[1] 48 p4 = Person[4] 49 p1.friends.add(p4) 50 self.assertEqual(set(p1.friends), {Person[2], Person[3], p4}) 51 def test1c(self): 52 p1 = Person[1] 53 p4 = Person[4] 54 p1.friends.add(p4) 55 commit() 56 rows = db.select("* from person_friends order by person, person_2") 57 self.assertEqual(rows, [(1,2), (1,3), (1,4), (2,1), (3,1), (4,1)]) 58 def test2a(self): 59 p1 = Person[1] 60 p2 = Person[2] 61 p1.friends.remove(p2) 62 self.assertEqual(set(p1.friends), {Person[3]}) 63 def test2b(self): 64 p1 = Person[1] 65 p2 = Person[2] 66 p1.friends.remove(p2) 67 self.assertEqual(set(Person[3].friends), {p1}) 68 def test2c(self): 69 p1 = Person[1] 70 p2 = Person[2] 71 p1.friends.remove(p2) 72 self.assertEqual(set(p2.friends), set()) 73 def test2d(self): 74 p1 = Person[1] 75 p2 = Person[2] 76 p1.friends.remove(p2) 77 commit() 78 rows = db.select("* from person_friends order by person, person_2") 79 self.assertEqual(rows, [(1,3), (3,1)]) 80 def test3a(self): 81 db.execute('delete from person_friends') 82 db.insert(Person.friends, person=1, person_2=2) 83 p1 = Person[1] 84 p2 = Person[2] 85 p2_friends = set(p2.friends) 86 self.assertEqual(p2_friends, set()) 87 try: p1_friends = set(p1.friends) 88 except UnrepeatableReadError as e: self.assertEqual(e.args[0], 89 "Phantom object Person[1] appeared in collection Person[2].friends") 90 else: self.fail() 91 def test3b(self): 92 db.execute('delete from person_friends') 93 db.insert(Person.friends, person=1, person_2=2) 94 p1 = Person[1] 95 p2 = Person[2] 96 p1_friends = set(p1.friends) 97 self.assertEqual(p1_friends, {p2}) 98 try: p2_friends = set(p2.friends) 99 except UnrepeatableReadError as e: self.assertEqual(e.args[0], 100 "Phantom object Person[1] disappeared from collection Person[2].friends") 101 else: self.fail() 102 103if __name__ == '__main__': 104 unittest.main() 105