1import pytest
2import sqlalchemy as sa
3
4from sqlalchemy_utils import merge_references
5
6
class TestMergeReferences(object):
    """merge_references() over a plain one-to-many foreign key.

    Every BlogPost points at its author through ``blog_post.author_id``.
    After merging one User into another, all posts are expected to point
    at the surviving user.
    """

    @pytest.fixture
    def User(self, Base):
        """Mapped ``user`` table: integer primary key plus a name."""
        class User(Base):
            __tablename__ = 'user'

            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            def __repr__(self):
                return 'User(%r)' % self.name

        return User

    @pytest.fixture
    def BlogPost(self, Base, User):
        """Mapped ``blog_post`` table whose ``author`` is a User."""
        class BlogPost(Base):
            __tablename__ = 'blog_post'

            id = sa.Column(sa.Integer, primary_key=True)
            title = sa.Column(sa.Unicode(255))
            content = sa.Column(sa.UnicodeText)
            # The reference the tests expect merge_references() to rewrite.
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))

            author = sa.orm.relationship(User)

        return BlogPost

    @pytest.fixture
    def init_models(self, User, BlogPost):
        """Request both model fixtures so their mappings are declared.

        NOTE(review): presumably consumed by a conftest fixture that creates
        the tables -- confirm.
        """

    @staticmethod
    def _persist_two_authors(session, User, BlogPost):
        """Commit John and Jack, one post each; return (john, jack, post, post2)."""
        john = User(name=u'John')
        jack = User(name=u'Jack')
        post = BlogPost(title=u'Some title', author=john)
        post2 = BlogPost(title=u'Other title', author=jack)
        for instance in (john, jack, post, post2):
            session.add(instance)
        session.commit()
        return john, jack, post, post2

    def test_updates_foreign_keys(self, session, User, BlogPost):
        john, jack, post, post2 = self._persist_two_authors(
            session, User, BlogPost
        )

        merge_references(john, jack)
        session.commit()

        # John's post has been handed over; Jack's post stays with Jack.
        assert post.author == jack
        assert post2.author == jack

    def test_object_merging_whenever_possible(self, session, User, BlogPost):
        john, jack, post, post2 = self._persist_two_authors(
            session, User, BlogPost
        )
        # Reading these attributes loads post and john back into memory
        # before the merge (assumes the session fixture keeps the default
        # expire_on_commit=True, so the commit above expired them).
        assert post.author_id == john.id

        merge_references(john, jack)

        # No commit here: the already-loaded objects are expected to see
        # the new foreign-key value immediately.
        assert post.author_id == jack.id
        assert post2.author_id == jack.id
66
67
class TestMergeReferencesWithManyToManyAssociations(object):
    """merge_references() through a plain many-to-many ``secondary`` table.

    ``team_member`` is a bare ``sa.Table`` with no mapped class of its own.
    ``Team.members`` and its ``User.teams`` backref both go through it.
    """

    @pytest.fixture
    def User(self, Base):
        """Mapped ``user`` table with an integer primary key and a name."""
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            def __repr__(self):
                return 'User(%r)' % self.name
        return User

    @pytest.fixture
    def Team(self, Base):
        """Mapped ``team`` table plus its ``team_member`` association table."""
        # Composite-key association table. Both foreign keys are declared
        # with ON DELETE CASCADE.
        team_member = sa.Table(
            'team_member', Base.metadata,
            sa.Column(
                'user_id', sa.Integer,
                sa.ForeignKey('user.id', ondelete='CASCADE'),
                primary_key=True
            ),
            sa.Column(
                'team_id', sa.Integer,
                sa.ForeignKey('team.id', ondelete='CASCADE'),
                primary_key=True
            )
        )

        class Team(Base):
            __tablename__ = 'team'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            # 'User' is referenced by class name; the backref adds a
            # ``teams`` collection to User.
            members = sa.orm.relationship(
                'User',
                secondary=team_member,
                backref='teams'
            )
        return Team

    @pytest.fixture
    def init_models(self, User, Team):
        """Request both model fixtures so their mappings are declared."""
        pass

    def test_supports_associations(self, session, User, Team):
        john = User(name=u'John')
        jack = User(name=u'Jack')
        team = Team(name=u'Team')
        team.members.append(john)
        session.add(john)
        session.add(jack)
        # Add the team explicitly rather than relying on it being cascaded
        # into the session via john's ``teams`` backref collection.
        session.add(team)
        session.commit()
        merge_references(john, jack)
        # The association row that linked John to the team should now
        # link Jack instead.
        assert john not in team.members
        assert jack in team.members
124
125
class TestMergeReferencesWithManyToManyAssociationObjects(object):
    """merge_references() through a mapped association object.

    The Team <-> User link is its own model, TeamMember, whose composite
    primary key is (user_id, team_id) and which also carries a role.
    """

    @pytest.fixture
    def Team(self, Base):
        """Mapped ``team`` table."""
        class Team(Base):
            __tablename__ = 'team'

            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        return Team

    @pytest.fixture
    def User(self, Base):
        """Mapped ``user`` table."""
        class User(Base):
            __tablename__ = 'user'

            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        return User

    @pytest.fixture
    def TeamMember(self, Base, User, Team):
        """Association object linking one User to one Team."""
        class TeamMember(Base):
            __tablename__ = 'team_member'

            user_id = sa.Column(
                sa.Integer,
                sa.ForeignKey(User.id, ondelete='CASCADE'),
                primary_key=True,
            )
            team_id = sa.Column(
                sa.Integer,
                sa.ForeignKey(Team.id, ondelete='CASCADE'),
                primary_key=True,
            )
            role = sa.Column(sa.Unicode(255))

            # Both backref collections (Team.members, User.memberships)
            # use cascade='all, delete-orphan'.
            team = sa.orm.relationship(
                Team,
                primaryjoin=team_id == Team.id,
                backref=sa.orm.backref(
                    'members', cascade='all, delete-orphan'
                ),
            )
            user = sa.orm.relationship(
                User,
                primaryjoin=user_id == User.id,
                backref=sa.orm.backref(
                    'memberships', cascade='all, delete-orphan'
                ),
            )

        return TeamMember

    @pytest.fixture
    def init_models(self, User, Team, TeamMember):
        """Request all three model fixtures so their mappings are declared."""

    def test_supports_associations(self, session, User, Team, TeamMember):
        john = User(name=u'John')
        jack = User(name=u'Jack')
        team = Team(name=u'Team')
        membership = TeamMember(user=john)
        team.members.append(membership)
        for instance in (john, jack, team):
            session.add(instance)
        session.commit()

        merge_references(john, jack)
        session.commit()

        # The membership that belonged to John is expected to belong to
        # Jack after the merge.
        member_users = [link.user for link in team.members]
        assert john not in member_users
        assert jack in member_users
195