1import sqlalchemy as sa
2from sqlalchemy import ForeignKey
3from sqlalchemy import Integer
4from sqlalchemy import String
5from sqlalchemy import testing
6from sqlalchemy.orm import backref
7from sqlalchemy.orm import exc as orm_exc
8from sqlalchemy.orm import relationship
9from sqlalchemy.testing import assert_raises_message
10from sqlalchemy.testing import eq_
11from sqlalchemy.testing import fixtures
12from sqlalchemy.testing.fixtures import fixture_session
13from sqlalchemy.testing.schema import Column
14from sqlalchemy.testing.schema import Table
15
16
class M2MTest(fixtures.MappedTest):
    """Many-to-many relationship tests over a place / transition schema.

    ``place`` and ``transition`` are linked through two structurally
    identical association tables (``place_input`` and ``place_output``).
    There is also a self-referential ``place_place`` association table and
    a ``place_thingy`` child table with a foreign key to ``place``.
    """

    @classmethod
    def define_tables(cls, metadata):
        """Create the place / transition tables and their link tables.

        :param metadata: the metadata object every fixture table below is
            registered on.
        """
        Table(
            "place",
            metadata,
            Column(
                "place_id",
                Integer,
                test_needs_autoincrement=True,
                primary_key=True,
            ),
            Column("name", String(30), nullable=False),
            test_needs_acid=True,
        )

        Table(
            "transition",
            metadata,
            Column(
                "transition_id",
                Integer,
                test_needs_autoincrement=True,
                primary_key=True,
            ),
            Column("name", String(30), nullable=False),
            test_needs_acid=True,
        )

        # child rows of "place" (mapped as Place.thingies in
        # test_joinedload_on_double)
        Table(
            "place_thingy",
            metadata,
            Column(
                "thingy_id",
                Integer,
                test_needs_autoincrement=True,
                primary_key=True,
            ),
            Column(
                "place_id",
                Integer,
                ForeignKey("place.place_id"),
                nullable=False,
            ),
            Column("name", String(30), nullable=False),
            test_needs_acid=True,
        )

        # association table #1: place <-> transition
        Table(
            "place_input",
            metadata,
            Column("place_id", Integer, ForeignKey("place.place_id")),
            Column(
                "transition_id",
                Integer,
                ForeignKey("transition.transition_id"),
            ),
            test_needs_acid=True,
        )

        # association table #2: place <-> transition, same shape as #1
        Table(
            "place_output",
            metadata,
            Column("place_id", Integer, ForeignKey("place.place_id")),
            Column(
                "transition_id",
                Integer,
                ForeignKey("transition.transition_id"),
            ),
            test_needs_acid=True,
        )

        # self-referential association table: both columns point at
        # place.place_id
        Table(
            "place_place",
            metadata,
            Column("pl1_id", Integer, ForeignKey("place.place_id")),
            Column("pl2_id", Integer, ForeignKey("place.place_id")),
            test_needs_acid=True,
        )

    @classmethod
    def setup_classes(cls):
        """Declare the plain classes mapped by the individual tests.

        Each class is constructed with only a ``name``.
        """

        class Place(cls.Basic):
            def __init__(self, name):
                self.name = name

        class PlaceThingy(cls.Basic):
            def __init__(self, name):
                self.name = name

        class Transition(cls.Basic):
            def __init__(self, name):
                self.name = name

    def test_overlapping_attribute_error(self):
        """A backref that collides with an explicit property is rejected.

        The ``Place.transitions`` backref would create ``Transition.places``.
        ``Transition`` also declares ``places`` explicitly, with a
        ``transitions`` backref of its own. Configuring the mappers must
        therefore raise ``ArgumentError``.
        """
        place, Transition, place_input, Place, transition = (
            self.tables.place,
            self.classes.Transition,
            self.tables.place_input,
            self.classes.Place,
            self.tables.transition,
        )

        self.mapper_registry.map_imperatively(
            Place,
            place,
            properties={
                "transitions": relationship(
                    Transition, secondary=place_input, backref="places"
                )
            },
        )
        self.mapper_registry.map_imperatively(
            Transition,
            transition,
            properties={
                "places": relationship(
                    Place, secondary=place_input, backref="transitions"
                )
            },
        )
        assert_raises_message(
            sa.exc.ArgumentError,
            "property of that name exists",
            sa.orm.configure_mappers,
        )

    def test_self_referential_roundtrip(self):
        """Persist and reload a self-referential many-to-many collection.

        ``Place.places`` joins through ``place_place``. ``pl1_id`` holds the
        owning place and ``pl2_id`` the place it refers to.
        """

        place, Place, place_place = (
            self.tables.place,
            self.classes.Place,
            self.tables.place_place,
        )

        self.mapper_registry.map_imperatively(
            Place,
            place,
            properties={
                "places": relationship(
                    Place,
                    secondary=place_place,
                    primaryjoin=place.c.place_id == place_place.c.pl1_id,
                    secondaryjoin=place.c.place_id == place_place.c.pl2_id,
                    order_by=place_place.c.pl2_id,
                )
            },
        )

        sess = fixture_session()
        p1 = Place("place1")
        p2 = Place("place2")
        p3 = Place("place3")
        p4 = Place("place4")
        p5 = Place("place5")
        p6 = Place("place6")
        p7 = Place("place7")
        sess.add_all((p1, p2, p3, p4, p5, p6, p7))
        # includes cycles (p1 -> p5 -> p6 -> p1, p3 <-> p4)
        p1.places.append(p2)
        p1.places.append(p3)
        p5.places.append(p6)
        p6.places.append(p1)
        p7.places.append(p1)
        p1.places.append(p5)
        p4.places.append(p3)
        p3.places.append(p4)
        sess.commit()

        # Expected orderings follow order_by=pl2_id. They assume place_id
        # values follow p1..p7, presumably from the add_all() order above
        # (TODO confirm).
        eq_(p1.places, [p2, p3, p5])
        eq_(p5.places, [p6])
        eq_(p7.places, [p1])
        eq_(p6.places, [p1])
        eq_(p4.places, [p3])
        eq_(p3.places, [p4])
        eq_(p2.places, [])

    def test_self_referential_bidirectional_mutation(self):
        """Mutate both ends of a self-referential backref before flushing.

        ``p2.parent_places = [p1]`` makes p1 a parent of p2.
        ``p1.parent_places.append(p2)`` makes p2 a parent of p1. After the
        commit, both links must be visible through ``parent_places``.
        """
        place, Place, place_place = (
            self.tables.place,
            self.classes.Place,
            self.tables.place_place,
        )

        self.mapper_registry.map_imperatively(
            Place,
            place,
            properties={
                "child_places": relationship(
                    Place,
                    secondary=place_place,
                    primaryjoin=place.c.place_id == place_place.c.pl1_id,
                    secondaryjoin=place.c.place_id == place_place.c.pl2_id,
                    order_by=place_place.c.pl2_id,
                    backref="parent_places",
                )
            },
        )

        sess = fixture_session()
        p1 = Place("place1")
        p2 = Place("place2")
        # mutation happens partly before and partly after add_all()
        p2.parent_places = [p1]
        sess.add_all([p1, p2])
        p1.parent_places.append(p2)
        sess.commit()

        assert p1 in p2.parent_places
        assert p2 in p1.parent_places

    def test_joinedload_on_double(self):
        """Test two eager (joined) relationships to the same table.

        ``Transition.inputs`` and ``Transition.outputs`` both target
        ``place``, each through a different association table. ``Place``
        itself eagerly loads ``thingies``. The joined eager load therefore
        needs aliases for the repeated tables.
        """

        (
            place_input,
            transition,
            Transition,
            PlaceThingy,
            place,
            place_thingy,
            Place,
            place_output,
        ) = (
            self.tables.place_input,
            self.tables.transition,
            self.classes.Transition,
            self.classes.PlaceThingy,
            self.tables.place,
            self.tables.place_thingy,
            self.classes.Place,
            self.tables.place_output,
        )

        self.mapper_registry.map_imperatively(PlaceThingy, place_thingy)
        self.mapper_registry.map_imperatively(
            Place,
            place,
            properties={"thingies": relationship(PlaceThingy, lazy="joined")},
        )

        # NOTE: the second positional argument is the secondary table.
        # "inputs" goes through place_output and "outputs" through
        # place_input. The two tables have the same columns, so the crossed
        # naming does not affect what is tested.
        self.mapper_registry.map_imperatively(
            Transition,
            transition,
            properties=dict(
                inputs=relationship(Place, place_output, lazy="joined"),
                outputs=relationship(Place, place_input, lazy="joined"),
            ),
        )

        tran = Transition("transition1")
        tran.inputs.append(Place("place1"))
        tran.outputs.append(Place("place2"))
        tran.outputs.append(Place("place3"))
        sess = fixture_session()
        sess.add(tran)
        sess.commit()

        r = sess.query(Transition).all()
        self.assert_unordered_result(
            r,
            Transition,
            {
                "name": "transition1",
                "inputs": (Place, [{"name": "place1"}]),
                "outputs": (Place, [{"name": "place2"}, {"name": "place3"}]),
            },
        )

    def test_bidirectional(self):
        """Populate many-to-many collections from both sides via backrefs.

        ``Transition.inputs`` / ``Transition.outputs`` are ordered by
        ``place_id``. Their backrefs on ``Place`` (also named ``inputs`` /
        ``outputs``) are ordered by ``transition_id``.
        """
        place_input, transition, Transition, Place, place, place_output = (
            self.tables.place_input,
            self.tables.transition,
            self.classes.Transition,
            self.classes.Place,
            self.tables.place,
            self.tables.place_output,
        )

        self.mapper_registry.map_imperatively(Place, place)
        self.mapper_registry.map_imperatively(
            Transition,
            transition,
            properties=dict(
                inputs=relationship(
                    Place,
                    place_output,
                    backref=backref(
                        "inputs", order_by=transition.c.transition_id
                    ),
                    order_by=Place.place_id,
                ),
                outputs=relationship(
                    Place,
                    place_input,
                    backref=backref(
                        "outputs", order_by=transition.c.transition_id
                    ),
                    order_by=Place.place_id,
                ),
            ),
        )

        t1 = Transition("transition1")
        t2 = Transition("transition2")
        t3 = Transition("transition3")
        p1 = Place("place1")
        p2 = Place("place2")
        p3 = Place("place3")

        sess = fixture_session()
        # The scrambled add order presumably gives place3 a lower place_id
        # than place1, which the t1.outputs assertion below relies on
        # (TODO confirm).
        sess.add_all([p3, p1, t1, t2, p2, t3])

        t1.inputs.append(p1)
        t1.inputs.append(p2)
        t1.outputs.append(p3)
        t2.inputs.append(p1)
        # via the backrefs: also adds p2 / p3 to t2.inputs
        p2.inputs.append(t2)
        p3.inputs.append(t2)
        # via the backref: also adds p1 to t1.outputs
        p1.outputs.append(t1)
        sess.commit()

        self.assert_result(
            [t1],
            Transition,
            {"outputs": (Place, [{"name": "place3"}, {"name": "place1"}])},
        )
        self.assert_result(
            [p2],
            Place,
            {
                "inputs": (
                    Transition,
                    [{"name": "transition1"}, {"name": "transition2"}],
                )
            },
        )

    # backend requirements; presumably the test is skipped where unmet
    @testing.requires.updateable_autoincrement_pks
    @testing.requires.sane_multi_rowcount
    def test_stale_conditions(self):
        """Association-row UPDATE / DELETE matching no rows is stale.

        With ``passive_updates=False`` the ORM itself issues the UPDATE on
        ``place_input`` when ``Place.place_id`` changes. Deleting the
        association rows behind the ORM's back makes that UPDATE match zero
        rows, and later the DELETE issued when a collection member is
        removed also matches zero rows. Both must raise ``StaleDataError``.
        """
        Place, Transition, place_input, place, transition = (
            self.classes.Place,
            self.classes.Transition,
            self.tables.place_input,
            self.tables.place,
            self.tables.transition,
        )

        self.mapper_registry.map_imperatively(
            Place,
            place,
            properties={
                "transitions": relationship(
                    Transition, secondary=place_input, passive_updates=False
                )
            },
        )
        self.mapper_registry.map_imperatively(Transition, transition)

        p1 = Place("place1")
        t1 = Transition("t1")
        p1.transitions.append(t1)
        sess = fixture_session()
        sess.add_all([p1, t1])
        sess.commit()

        # Touch the attributes so the ORM holds their loaded state before
        # the association rows are removed. Presumably this is required for
        # it to emit the association-row UPDATE expected below.
        p1.place_id
        p1.transitions

        # remove the association rows without the ORM knowing
        sess.execute(place_input.delete())
        p1.place_id = 7

        assert_raises_message(
            orm_exc.StaleDataError,
            r"UPDATE statement on table 'place_input' expected to "
            r"update 1 row\(s\); Only 0 were matched.",
            sess.commit,
        )
        sess.rollback()

        # same setup again, now exercising the DELETE path
        p1.place_id
        p1.transitions
        sess.execute(place_input.delete())
        p1.transitions.remove(t1)
        assert_raises_message(
            orm_exc.StaleDataError,
            r"DELETE statement on table 'place_input' expected to "
            r"delete 1 row\(s\); Only 0 were matched.",
            sess.commit,
        )
409
410
class AssortedPersistenceTests(fixtures.MappedTest):
    """Persistence behaviour of a plain ``A`` <-> ``B`` many-to-many mapping.

    ``A`` is mapped to ``left`` and ``B`` to ``right``. The two are linked
    through the ``secondary`` association table.
    """

    @classmethod
    def define_tables(cls, metadata):
        """Create ``left``, ``right`` and their ``secondary`` link table.

        :param metadata: the metadata object the fixture tables are
            registered on.
        """
        Table(
            "left",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(30)),
        )

        Table(
            "right",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(30)),
        )

        # composite primary key: each (left, right) pair can be linked at
        # most once
        Table(
            "secondary",
            metadata,
            Column(
                "left_id", Integer, ForeignKey("left.id"), primary_key=True
            ),
            Column(
                "right_id", Integer, ForeignKey("right.id"), primary_key=True
            ),
        )

    @classmethod
    def setup_classes(cls):
        """Declare the comparable classes mapped to ``left`` / ``right``."""

        class A(cls.Comparable):
            pass

        class B(cls.Comparable):
            pass

    def _map_a_to_b(self, b_backref):
        """Map ``A`` and ``B`` with an ``A.bs`` many-to-many collection.

        This is shared by the fixture methods below, which differ only in
        the backref that ``B`` receives.

        :param b_backref: value passed as ``backref=`` to the ``A.bs``
            relationship. Either a plain attribute name or a ``backref()``
            construct.
        """
        left, secondary, right = (
            self.tables.left,
            self.tables.secondary,
            self.tables.right,
        )
        A, B = self.classes.A, self.classes.B
        self.mapper_registry.map_imperatively(
            A,
            left,
            properties={
                "bs": relationship(
                    B,
                    secondary=secondary,
                    backref=b_backref,
                    order_by=right.c.id,
                )
            },
        )
        self.mapper_registry.map_imperatively(B, right)

    def _standard_bidirectional_fixture(self):
        """Map ``A.bs`` with a ``B.as`` backref using default options.

        ``as`` is a Python keyword, so that attribute can only be reached
        via ``getattr(b, "as")``.
        """
        self._map_a_to_b("as")

    def _bidirectional_onescalar_fixture(self):
        """Map ``A.bs`` with a scalar (``uselist=False``) ``B.a`` backref."""
        self._map_a_to_b(backref("a", uselist=False))

    def test_session_delete(self):
        """Deleting an ``A`` removes only its rows from ``secondary``."""
        self._standard_bidirectional_fixture()
        A, B = self.classes.A, self.classes.B
        secondary = self.tables.secondary

        sess = fixture_session()
        sess.add_all(
            [A(data="a1", bs=[B(data="b1")]), A(data="a2", bs=[B(data="b2")])]
        )
        sess.commit()

        # two association rows exist; each delete should remove exactly one
        a1 = sess.query(A).filter_by(data="a1").one()
        sess.delete(a1)
        sess.flush()
        eq_(sess.query(secondary).count(), 1)

        a2 = sess.query(A).filter_by(data="a2").one()
        sess.delete(a2)
        sess.flush()
        eq_(sess.query(secondary).count(), 0)

    def test_remove_scalar(self):
        """Setting a ``uselist=False`` backref to None unlinks the pair.

        Only the ``a1``/``b2`` association row is removed. ``b1`` stays
        linked to ``a1``.
        """
        self._bidirectional_onescalar_fixture()
        A, B = self.classes.A, self.classes.B
        secondary = self.tables.secondary

        sess = fixture_session()
        sess.add_all([A(data="a1", bs=[B(data="b1"), B(data="b2")])])
        sess.commit()

        a1 = sess.query(A).filter_by(data="a1").one()
        b2 = sess.query(B).filter_by(data="b2").one()
        assert b2.a is a1

        b2.a = None
        sess.commit()

        eq_(a1.bs, [B(data="b1")])
        eq_(b2.a, None)
        eq_(sess.query(secondary).count(), 1)
531