import pytest
import sqlalchemy as sa

from sqlalchemy_utils import merge_references


class TestMergeReferences(object):
    """Exercise ``merge_references`` against a plain many-to-one foreign key."""

    @pytest.fixture
    def User(self, Base):
        """Return a ``User`` model mapped to the ``user`` table."""
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            def __repr__(self):
                return 'User(%r)' % self.name

        return User

    @pytest.fixture
    def BlogPost(self, Base, User):
        """Return a ``BlogPost`` model whose ``author_id`` points at ``user.id``."""
        class BlogPost(Base):
            __tablename__ = 'blog_post'
            id = sa.Column(sa.Integer, primary_key=True)
            title = sa.Column(sa.Unicode(255))
            content = sa.Column(sa.UnicodeText)
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))

            author = sa.orm.relationship(User)

        return BlogPost

    @pytest.fixture
    def init_models(self, User, BlogPost):
        """Request the model fixtures so both classes get declared.

        NOTE(review): presumably consumed by a shared conftest fixture that
        creates the tables -- confirm against the project's conftest.
        """

    def test_updates_foreign_keys(self, session, User, BlogPost):
        """Both posts should reference the target user after merge + commit."""
        source_user = User(name=u'John')
        target_user = User(name=u'Jack')
        first_post = BlogPost(title=u'Some title', author=source_user)
        second_post = BlogPost(title=u'Other title', author=target_user)
        session.add_all([source_user, target_user, first_post, second_post])
        session.commit()

        merge_references(source_user, target_user)
        session.commit()

        assert first_post.author == target_user
        assert second_post.author == target_user

    def test_object_merging_whenever_possible(self, session, User, BlogPost):
        """In-session posts should see the new ``author_id`` without a commit."""
        source_user = User(name=u'John')
        target_user = User(name=u'Jack')
        first_post = BlogPost(title=u'Some title', author=source_user)
        second_post = BlogPost(title=u'Other title', author=target_user)
        session.add_all([source_user, target_user, first_post, second_post])
        session.commit()

        # Read author_id on the committed post before merging, while it
        # still refers to the source user.
        assert first_post.author_id == source_user.id

        merge_references(source_user, target_user)

        assert first_post.author_id == target_user.id
        assert second_post.author_id == target_user.id


class TestMergeReferencesWithManyToManyAssociations(object):
    """Exercise ``merge_references`` through a plain association table."""

    @pytest.fixture
    def User(self, Base):
        """Return a ``User`` model mapped to the ``user`` table."""
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            def __repr__(self):
                return 'User(%r)' % self.name

        return User

    @pytest.fixture
    def Team(self, Base):
        """Return a ``Team`` model linked to users via ``team_member``."""
        # Association table with a composite primary key; it has no mapped
        # class of its own.
        user_fk_column = sa.Column(
            'user_id',
            sa.Integer,
            sa.ForeignKey('user.id', ondelete='CASCADE'),
            primary_key=True
        )
        team_fk_column = sa.Column(
            'team_id',
            sa.Integer,
            sa.ForeignKey('team.id', ondelete='CASCADE'),
            primary_key=True
        )
        association_table = sa.Table(
            'team_member',
            Base.metadata,
            user_fk_column,
            team_fk_column
        )

        class Team(Base):
            __tablename__ = 'team'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            members = sa.orm.relationship(
                'User',
                secondary=association_table,
                backref='teams'
            )

        return Team

    @pytest.fixture
    def init_models(self, User, Team):
        """Request the model fixtures so both classes get declared.

        NOTE(review): presumably consumed by a shared conftest fixture that
        creates the tables -- confirm against the project's conftest.
        """

    def test_supports_associations(self, session, User, Team):
        """Team membership should move from the source user to the target."""
        source_user = User(name=u'John')
        target_user = User(name=u'Jack')
        team = Team(name=u'Team')
        team.members.append(source_user)
        # The team itself is never passed to the session explicitly; only
        # the two users are added.
        session.add_all([source_user, target_user])
        session.commit()

        merge_references(source_user, target_user)

        assert source_user not in team.members
        assert target_user in team.members


class TestMergeReferencesWithManyToManyAssociationObjects(object):
    """Exercise ``merge_references`` through a mapped association class."""

    @pytest.fixture
    def Team(self, Base):
        """Return a ``Team`` model mapped to the ``team`` table."""
        class Team(Base):
            __tablename__ = 'team'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        return Team

    @pytest.fixture
    def User(self, Base):
        """Return a ``User`` model mapped to the ``user`` table."""
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        return User

    @pytest.fixture
    def TeamMember(self, Base, User, Team):
        """Return the ``TeamMember`` association model joining users and teams."""
        class TeamMember(Base):
            __tablename__ = 'team_member'
            user_id = sa.Column(
                sa.Integer,
                sa.ForeignKey(User.id, ondelete='CASCADE'),
                primary_key=True
            )
            team_id = sa.Column(
                sa.Integer,
                sa.ForeignKey(Team.id, ondelete='CASCADE'),
                primary_key=True
            )
            role = sa.Column(sa.Unicode(255))
            # Each relationship installs a backref on the other side
            # (``Team.members`` / ``User.memberships``).
            team = sa.orm.relationship(
                Team,
                backref=sa.orm.backref(
                    'members',
                    cascade='all, delete-orphan'
                ),
                primaryjoin=team_id == Team.id,
            )
            user = sa.orm.relationship(
                User,
                backref=sa.orm.backref(
                    'memberships',
                    cascade='all, delete-orphan'
                ),
                primaryjoin=user_id == User.id,
            )

        return TeamMember

    @pytest.fixture
    def init_models(self, User, Team, TeamMember):
        """Request the model fixtures so all three classes get declared.

        NOTE(review): presumably consumed by a shared conftest fixture that
        creates the tables -- confirm against the project's conftest.
        """

    def test_supports_associations(self, session, User, Team, TeamMember):
        """The membership row should point at the target user after merging."""
        source_user = User(name=u'John')
        target_user = User(name=u'Jack')
        team = Team(name=u'Team')
        membership = TeamMember(user=source_user)
        team.members.append(membership)
        session.add_all([source_user, target_user, team])
        session.commit()

        merge_references(source_user, target_user)
        session.commit()

        member_users = [entry.user for entry in team.members]
        assert source_user not in member_users
        assert target_user in member_users