1import sqlalchemy as sa
2from sqlalchemy import event
3from sqlalchemy import ForeignKey
4from sqlalchemy import inspect
5from sqlalchemy import Integer
6from sqlalchemy import Sequence
7from sqlalchemy import String
8from sqlalchemy import testing
9from sqlalchemy.orm import attributes
10from sqlalchemy.orm import backref
11from sqlalchemy.orm import close_all_sessions
12from sqlalchemy.orm import create_session
13from sqlalchemy.orm import exc as orm_exc
14from sqlalchemy.orm import joinedload
15from sqlalchemy.orm import make_transient
16from sqlalchemy.orm import make_transient_to_detached
17from sqlalchemy.orm import mapper
18from sqlalchemy.orm import object_session
19from sqlalchemy.orm import relationship
20from sqlalchemy.orm import Session
21from sqlalchemy.orm import sessionmaker
22from sqlalchemy.orm import was_deleted
23from sqlalchemy.testing import assert_raises
24from sqlalchemy.testing import assert_raises_message
25from sqlalchemy.testing import assertions
26from sqlalchemy.testing import config
27from sqlalchemy.testing import engines
28from sqlalchemy.testing import eq_
29from sqlalchemy.testing import fixtures
30from sqlalchemy.testing import is_
31from sqlalchemy.testing import is_true
32from sqlalchemy.testing import mock
33from sqlalchemy.testing import pickleable
34from sqlalchemy.testing.schema import Column
35from sqlalchemy.testing.schema import Table
36from sqlalchemy.testing.util import gc_collect
37from sqlalchemy.util import pickle
38from sqlalchemy.util.compat import inspect_getfullargspec
39from test.orm import _fixtures
40
41
class ExecutionTest(_fixtures.FixtureTest):
    """Tests for executing Core constructs and textual SQL via Session."""

    run_inserts = None
    __backend__ = True

    @testing.requires.sequences
    def test_sequence_execute(self):
        """Session.execute() of a freshly created Sequence returns 1."""
        seq = Sequence("some_sequence")
        seq.create(testing.db)
        try:
            sess = create_session(bind=testing.db)
            eq_(sess.execute(seq), 1)
        finally:
            # always drop the sequence so later tests start clean
            seq.drop(testing.db)

    def test_textual_execute(self):
        """test that Session.execute() converts to text()"""

        users = self.tables.users

        sess = create_session(bind=self.metadata.bind)
        users.insert().execute(id=7, name="jack")

        # use :bindparam style
        eq_(
            sess.execute(
                "select * from users where id=:id", {"id": 7}
            ).fetchall(),
            [(7, "jack")],
        )

        # Session.scalar() accepts the same :bindparam style string
        eq_(sess.scalar("select id from users where id=:id", {"id": 7}), 7)

    def test_parameter_execute(self):
        """Session.execute() accepts a list of parameter dictionaries as
        well as a single parameter dictionary."""
        users = self.tables.users
        sess = Session(bind=testing.db)
        sess.execute(
            users.insert(), [{"id": 7, "name": "u7"}, {"id": 8, "name": "u8"}]
        )
        sess.execute(users.insert(), {"id": 9, "name": "u9"})
        eq_(
            sess.execute(
                sa.select([users.c.id]).order_by(users.c.id)
            ).fetchall(),
            [(7,), (8,), (9,)],
        )
88
89
class TransScopingTest(_fixtures.FixtureTest):
    """Tests that a Session leaves externally supplied connections open and
    honors transaction boundaries on them."""

    run_inserts = None
    __prefer_requires__ = ("independent_connections",)

    def test_no_close_on_flush(self):
        """Flush() doesn't close a connection the session didn't open"""

        User, users = self.classes.User, self.tables.users

        c = testing.db.connect()
        try:
            c.execute("select * from users")

            mapper(User, users)
            s = create_session(bind=c)
            s.add(User(name="first"))
            s.flush()
            # the connection must still be usable after the flush
            c.execute("select * from users")
        finally:
            # this test opened the connection, so it is responsible for
            # returning it; without this the connection leaks
            c.close()

    def test_close(self):
        """close() doesn't close a connection the session didn't open"""

        User, users = self.classes.User, self.tables.users

        c = testing.db.connect()
        try:
            c.execute("select * from users")

            mapper(User, users)
            s = create_session(bind=c)
            s.add(User(name="first"))
            s.flush()
            c.execute("select * from users")
            s.close()
            # the connection must still be usable after Session.close()
            c.execute("select * from users")
        finally:
            # this test opened the connection, so it is responsible for
            # returning it; without this the connection leaks
            c.close()

    @testing.requires.independent_connections
    @engines.close_open_connections
    def test_transaction(self):
        """Rows flushed on conn1 are invisible to conn2 until commit, after
        which a brand new connection sees them."""
        User, users = self.classes.User, self.tables.users

        mapper(User, users)
        conn1 = testing.db.connect()
        conn2 = testing.db.connect()

        sess = create_session(autocommit=False, bind=conn1)
        u = User(name="x")
        sess.add(u)
        sess.flush()
        assert conn1.execute("select count(1) from users").scalar() == 1
        assert conn2.execute("select count(1) from users").scalar() == 0
        sess.commit()
        assert conn1.execute("select count(1) from users").scalar() == 1

        assert (
            testing.db.connect().execute("select count(1) from users").scalar()
            == 1
        )
        sess.close()
147
148
class SessionUtilTest(_fixtures.FixtureTest):
    """Tests for session-related utility functions: close_all_sessions(),
    Session.close_all(), object_session(), make_transient() and
    make_transient_to_detached()."""

    run_inserts = None

    def test_close_all_sessions(self):
        """close_all_sessions() removes objects from every open Session."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        s1 = Session()
        u1 = User()
        s1.add(u1)

        s2 = Session()
        u2 = User()
        s2.add(u2)

        assert u1 in s1
        assert u2 in s2

        close_all_sessions()

        assert u1 not in s1
        assert u2 not in s2

    def test_session_close_all_deprecated(self):
        """Session.close_all() emits a deprecation warning and, like
        close_all_sessions(), empties every open Session."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        s1 = Session()
        u1 = User()
        s1.add(u1)

        s2 = Session()
        u2 = User()
        s2.add(u2)

        assert u1 in s1
        assert u2 in s2

        with assertions.expect_deprecated(
            r"The Session.close_all\(\) method is deprecated and will "
            "be removed in a future release. "
        ):
            Session.close_all()

        assert u1 not in s1
        assert u2 not in s2

    def test_object_session_raises(self):
        """object_session() raises UnmappedInstanceError for a plain
        object() and for a User instance (User is not mapped in this
        test)."""
        User = self.classes.User

        assert_raises(orm_exc.UnmappedInstanceError, object_session, object())

        assert_raises(orm_exc.UnmappedInstanceError, object_session, User())

    def test_make_transient(self):
        """make_transient() turns persistent, detached, expired and deleted
        objects back into transient ones that can be re-added."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = create_session()
        sess.add(User(name="test"))
        sess.flush()

        # persistent -> transient; re-adding makes it pending
        u1 = sess.query(User).first()
        make_transient(u1)
        assert u1 not in sess
        sess.add(u1)
        assert u1 in sess.new

        # detached -> transient; re-adding makes it pending
        u1 = sess.query(User).first()
        sess.expunge(u1)
        make_transient(u1)
        sess.add(u1)
        assert u1 in sess.new

        # test expired attributes
        # get unexpired
        u1 = sess.query(User).first()
        sess.expire(u1)
        make_transient(u1)
        # expired attributes read as None instead of being reloaded
        assert u1.id is None
        assert u1.name is None

        # works twice
        make_transient(u1)

        sess.close()

        u1.name = "test2"
        sess.add(u1)
        sess.flush()
        assert u1 in sess
        sess.delete(u1)
        sess.flush()
        assert u1 not in sess

        # a flushed-deleted object can't be re-added until it is made
        # transient again
        assert_raises(sa.exc.InvalidRequestError, sess.add, u1)
        make_transient(u1)
        sess.add(u1)
        sess.flush()
        assert u1 in sess

    def test_make_transient_plus_rollback(self):
        """An object made transient after a flushed delete is still
        transient after the session rolls back."""
        # test for [ticket:2182]
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()
        u1 = User(name="test")
        sess.add(u1)
        sess.commit()

        sess.delete(u1)
        sess.flush()
        make_transient(u1)
        sess.rollback()
        assert attributes.instance_state(u1).transient

    def test_make_transient_to_detached(self):
        """make_transient_to_detached() keeps the primary key set on a
        transient object; once added to a session, its other attributes
        are loaded from the database on access."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()
        u1 = User(id=1, name="test")
        sess.add(u1)
        sess.commit()
        sess.close()

        u2 = User(id=1)
        make_transient_to_detached(u2)
        assert "id" in u2.__dict__
        sess.add(u2)
        eq_(u2.name, "test")

    def test_make_transient_to_detached_no_session_allowed(self):
        """make_transient_to_detached() rejects an object that has been
        added to a session."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()
        u1 = User(id=1, name="test")
        sess.add(u1)
        assert_raises_message(
            sa.exc.InvalidRequestError,
            "Given object must be transient",
            make_transient_to_detached,
            u1,
        )

    def test_make_transient_to_detached_no_key_allowed(self):
        """make_transient_to_detached() rejects a detached object that was
        previously committed (i.e. is not transient)."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()
        u1 = User(id=1, name="test")
        sess.add(u1)
        sess.commit()
        sess.expunge(u1)
        assert_raises_message(
            sa.exc.InvalidRequestError,
            "Given object must be transient",
            make_transient_to_detached,
            u1,
        )
313
314
class SessionStateTest(_fixtures.FixtureTest):
    """Tests for Session state handling: info, autoflush, add/delete/
    expunge, identity-map conflicts and deleted-object bookkeeping."""

    run_inserts = None

    __prefer_requires__ = ("independent_connections",)

    def test_info(self):
        """Session.info merges sessionmaker-level info with per-session
        info; changes made on one session don't leak into new ones."""
        s = Session()
        eq_(s.info, {})

        maker = sessionmaker(info={"global": True, "s1": 5})

        s1 = maker()
        s2 = maker(info={"s1": 6, "s2": True})

        eq_(s1.info, {"global": True, "s1": 5})
        eq_(s2.info, {"global": True, "s1": 6, "s2": True})
        s2.info["global"] = False
        s2.info["s1"] = 7

        s3 = maker()
        eq_(s3.info, {"global": True, "s1": 5})

        maker2 = sessionmaker()
        s4 = maker2(info={"s4": 8})
        eq_(s4.info, {"s4": 8})

    @testing.requires.independent_connections
    @engines.close_open_connections
    def test_autoflush(self):
        """A query autoflushes the pending object within the transaction
        on conn1; conn2 sees the row only after commit."""
        User, users = self.classes.User, self.tables.users

        bind = self.metadata.bind
        mapper(User, users)
        conn1 = bind.connect()
        conn2 = bind.connect()

        sess = create_session(bind=conn1, autocommit=False, autoflush=True)
        u = User()
        u.name = "ed"
        sess.add(u)
        u2 = sess.query(User).filter_by(name="ed").one()
        assert u2 is u
        eq_(conn1.execute("select count(1) from users").scalar(), 1)
        eq_(conn2.execute("select count(1) from users").scalar(), 0)
        sess.commit()
        eq_(conn1.execute("select count(1) from users").scalar(), 1)
        eq_(bind.connect().execute("select count(1) from users").scalar(), 1)
        sess.close()

    def test_with_no_autoflush(self):
        """Queries inside a no_autoflush block don't flush pending objects;
        a query after the block does."""
        User, users = self.classes.User, self.tables.users

        mapper(User, users)
        sess = Session()

        u = User()
        u.name = "ed"
        sess.add(u)

        def go(obj):
            assert u not in sess.query(User).all()

        testing.run_as_contextmanager(sess.no_autoflush, go)
        assert u in sess.new
        assert u in sess.query(User).all()
        assert u not in sess.new

    def test_with_no_autoflush_after_exception(self):
        """autoflush is True again after a no_autoflush block raises."""
        sess = Session(autoflush=True)

        assert_raises(
            ZeroDivisionError,
            testing.run_as_contextmanager,
            sess.no_autoflush,
            lambda obj: 1 / 0,
        )

        is_true(sess.autoflush)

    def test_autoflush_exception_addition(self):
        """A DBAPIError raised during a Query-invoked autoflush carries a
        hint about using session.no_autoflush."""
        User, users = self.classes.User, self.tables.users
        Address, addresses = self.classes.Address, self.tables.addresses
        mapper(User, users, properties={"addresses": relationship(Address)})
        mapper(Address, addresses)

        s = Session(testing.db)

        u1 = User(name="first")

        s.add(u1)
        s.commit()

        u1.addresses.append(Address(email=None))

        # will raise for null email address
        assert_raises_message(
            sa.exc.DBAPIError,
            ".*raised as a result of Query-invoked autoflush; consider using "
            "a session.no_autoflush block.*",
            s.query(User).first,
        )

    def test_deleted_flag(self):
        """A flushed-deleted object can't be re-added; rollback restores
        it, and make_transient() allows re-adding after commit."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        sess = sessionmaker()()

        u1 = User(name="u1")
        sess.add(u1)
        sess.commit()

        sess.delete(u1)
        sess.flush()
        assert u1 not in sess
        assert_raises(sa.exc.InvalidRequestError, sess.add, u1)
        sess.rollback()
        assert u1 in sess

        sess.delete(u1)
        sess.commit()
        assert u1 not in sess
        assert_raises(sa.exc.InvalidRequestError, sess.add, u1)

        make_transient(u1)
        sess.add(u1)
        sess.commit()

        eq_(sess.query(User).count(), 1)

    @testing.requires.sane_rowcount
    def test_deleted_adds_to_imap_unconditionally(self):
        """Deleting an already flushed-deleted object puts it back in the
        session; the resulting DELETE matches 0 rows and warns."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        sess = Session()
        u1 = User(name="u1")
        sess.add(u1)
        sess.commit()

        sess.delete(u1)
        sess.flush()

        # object is not in session
        assert u1 not in sess

        # but it *is* attached
        assert u1._sa_instance_state.session_id == sess.hash_key

        # mark as deleted again
        sess.delete(u1)

        # in the session again
        assert u1 in sess

        # commit proceeds w/ warning
        with assertions.expect_warnings(
            "DELETE statement on table 'users' "
            r"expected to delete 1 row\(s\); 0 were matched."
        ):
            sess.commit()

    @testing.requires.independent_connections
    @engines.close_open_connections
    def test_autoflush_unbound(self):
        """Autoflush with a session created without a bind; textual SQL is
        executed with mapper=User."""
        User, users = self.classes.User, self.tables.users

        mapper(User, users)
        # Create the session before entering the try block, so that the
        # except clause never refers to an unassigned ``sess`` (which would
        # raise UnboundLocalError and hide the original error).
        sess = create_session(autocommit=False, autoflush=True)
        try:
            u = User()
            u.name = "ed"
            sess.add(u)
            u2 = sess.query(User).filter_by(name="ed").one()
            assert u2 is u
            assert (
                sess.execute(
                    "select count(1) from users", mapper=User
                ).scalar()
                == 1
            )
            assert (
                testing.db.connect()
                .execute("select count(1) from users")
                .scalar()
                == 0
            )
            sess.commit()
            assert (
                sess.execute(
                    "select count(1) from users", mapper=User
                ).scalar()
                == 1
            )
            assert (
                testing.db.connect()
                .execute("select count(1) from users")
                .scalar()
                == 1
            )
            sess.close()
        except Exception:
            sess.rollback()
            raise

    @engines.close_open_connections
    def test_autoflush_2(self):
        """commit() flushes the pending object; the row is visible on the
        session's connection and on a new connection."""
        User, users = self.classes.User, self.tables.users

        mapper(User, users)
        conn1 = testing.db.connect()
        sess = create_session(bind=conn1, autocommit=False, autoflush=True)
        u = User()
        u.name = "ed"
        sess.add(u)
        sess.commit()
        assert conn1.execute("select count(1) from users").scalar() == 1
        assert (
            testing.db.connect().execute("select count(1) from users").scalar()
            == 1
        )
        # a second commit, with nothing pending since the first one
        sess.commit()

    def test_autocommit_doesnt_raise_on_pending(self):
        """In autocommit mode, begin() with an object already pending
        followed by flush() and commit() doesn't raise."""
        User, users = self.classes.User, self.tables.users

        mapper(User, users)
        session = create_session(autocommit=True)

        session.add(User(name="ed"))

        session.begin()
        session.flush()
        session.commit()

    def test_active_flag(self):
        """In autocommit mode, is_active is True only between begin() and
        rollback()."""
        sess = create_session(bind=config.db, autocommit=True)
        assert not sess.is_active
        sess.begin()
        assert sess.is_active
        sess.rollback()
        assert not sess.is_active

    @engines.close_open_connections
    def test_add_delete(self):
        """Exercise add()/delete()/expunge() for transient, detached,
        dirty and cross-session objects."""
        User, Address, addresses, users = (
            self.classes.User,
            self.classes.Address,
            self.tables.addresses,
            self.tables.users,
        )

        s = create_session()
        mapper(
            User,
            users,
            properties={
                "addresses": relationship(Address, cascade="all, delete")
            },
        )
        mapper(Address, addresses)

        user = User(name="u1")

        assert_raises_message(
            sa.exc.InvalidRequestError, "is not persisted", s.delete, user
        )

        s.add(user)
        s.flush()
        user = s.query(User).one()
        s.expunge(user)
        assert user not in s

        # modify outside of session, assert changes remain/get saved
        user.name = "fred"
        s.add(user)
        assert user in s
        assert user in s.dirty
        s.flush()
        s.expunge_all()
        assert s.query(User).count() == 1
        user = s.query(User).one()
        assert user.name == "fred"

        # ensure its not dirty if no changes occur
        s.expunge_all()
        assert user not in s
        s.add(user)
        assert user in s
        assert user not in s.dirty

        s2 = create_session()
        assert_raises_message(
            sa.exc.InvalidRequestError,
            "is already attached to session",
            s2.delete,
            user,
        )
        u2 = s2.query(User).get(user.id)
        s2.expunge(u2)
        assert_raises_message(
            sa.exc.InvalidRequestError,
            "another instance .* is already present",
            s.delete,
            u2,
        )
        s.expire(user)
        s.expunge(user)
        assert user not in s
        s.delete(user)
        assert user in s

        s.flush()
        assert user not in s
        assert s.query(User).count() == 0

    def test_already_attached(self):
        """Adding an object owned by another session raises and leaves the
        second session's identity map empty."""
        User = self.classes.User
        users = self.tables.users
        mapper(User, users)

        s1 = Session()
        s2 = Session()

        u1 = User(id=1, name="u1")
        make_transient_to_detached(u1)  # shorthand for actually persisting it
        s1.add(u1)

        assert_raises_message(
            sa.exc.InvalidRequestError,
            "Object '<User.*?>' is already attached to session",
            s2.add,
            u1,
        )
        assert u1 not in s2
        assert not s2.identity_map.keys()

    def test_identity_conflict(self):
        """identity_map.add() refuses a second instance whose identity key
        is already present."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        for s in (create_session(), create_session()):
            users.delete().execute()
            u1 = User(name="ed")
            s.add(u1)
            s.flush()
            s.expunge(u1)
            u2 = s.query(User).first()
            s.expunge(u2)
            s.identity_map.add(sa.orm.attributes.instance_state(u1))

            assert_raises_message(
                sa.exc.InvalidRequestError,
                "Can't attach instance <User.*?>; another instance "
                "with key .*? is already "
                "present in this session.",
                s.identity_map.add,
                sa.orm.attributes.instance_state(u2),
            )

    def test_internal_identity_conflict_warning_weak(self):
        """Identity-conflict warning test with weak_identity_map=True."""
        self._test_internal_identity_conflict_warning(True)

    def test_internal_identity_conflict_warning_strong(self):
        """Identity-conflict warning test with weak_identity_map=False."""
        self._test_internal_identity_conflict_warning(False)

    def _test_internal_identity_conflict_warning(self, weak_identity_map):
        """Flushing while u1.addresses is expired and a2 is only in pending
        mutations warns that an identity-map entry is being replaced.

        :param weak_identity_map: passed to the (deprecated) Session
            ``weak_identity_map`` argument.
        """
        # test for issue #4890
        # see also test_naturalpks::ReversePKsTest::test_reverse
        users, User = self.tables.users, self.classes.User
        addresses, Address = self.tables.addresses, self.classes.Address

        mapper(
            User,
            users,
            properties={"addresses": relationship(Address, backref="user")},
        )
        mapper(Address, addresses)

        with testing.expect_deprecated():
            session = Session(weak_identity_map=weak_identity_map)

        @event.listens_for(session, "after_flush")
        def load_collections(session, flush_context):
            # touching the collection inside after_flush triggers its load
            for target in set(session.new).union(session.dirty):
                if isinstance(target, User):
                    target.addresses

        u1 = User(name="u1")
        a1 = Address(email_address="e1", user=u1)
        session.add_all([u1, a1])
        session.flush()

        session.expire_all()

        # create new Address via backref, so that u1.addresses remains
        # expired and a2 is in pending mutations
        a2 = Address(email_address="e2", user=u1)
        assert "addresses" not in inspect(u1).dict
        assert a2 in inspect(u1)._pending_mutations["addresses"].added_items

        with assertions.expect_warnings(
            r"Identity map already had an identity "
            r"for \(.*Address.*\), replacing"
        ):
            session.flush()

    def test_pickled_update(self):
        """An object attached to one session can't be added to another,
        but an unpickled copy of it can."""
        users, User = self.tables.users, pickleable.User

        mapper(User, users)
        sess1 = create_session()
        sess2 = create_session()
        u1 = User(name="u1")
        sess1.add(u1)
        assert_raises_message(
            sa.exc.InvalidRequestError,
            "already attached to session",
            sess2.add,
            u1,
        )
        u2 = pickle.loads(pickle.dumps(u1))
        sess2.add(u2)

    def test_duplicate_update(self):
        """Re-adding an expunged object whose identity is already loaded
        raises; after expunging the conflicting copy it can be re-added
        and its changes flushed."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        Session = sessionmaker()
        sess = Session()

        u1 = User(name="u1")
        sess.add(u1)
        sess.flush()
        assert u1.id is not None

        sess.expunge(u1)

        assert u1 not in sess
        assert Session.object_session(u1) is None

        u2 = sess.query(User).get(u1.id)
        assert u2 is not None and u2 is not u1
        assert u2 in sess

        assert_raises_message(
            sa.exc.InvalidRequestError,
            "Can't attach instance <User.*?>; another instance "
            "with key .*? is already "
            "present in this session.",
            sess.add,
            u1,
        )

        sess.expunge(u2)
        assert u2 not in sess
        assert Session.object_session(u2) is None

        u1.name = "John"
        u2.name = "Doe"

        sess.add(u1)
        assert u1 in sess
        assert Session.object_session(u1) is sess

        sess.flush()

        sess.expunge_all()

        u3 = sess.query(User).get(u1.id)
        assert u3 is not u1 and u3 is not u2 and u3.name == u1.name

    def test_no_double_save(self):
        """Calling session.add() on the same object from both a subclass
        and base-class __init__ results in a single session entry."""
        users = self.tables.users

        sess = create_session()

        class Foo(object):
            def __init__(self):
                sess.add(self)

        class Bar(Foo):
            def __init__(self):
                sess.add(self)
                Foo.__init__(self)

        mapper(Foo, users)
        mapper(Bar, users)

        b = Bar()
        assert b in sess
        assert len(list(sess)) == 1

    def test_identity_map_mutate(self):
        """Iterating identity_map.items() while an entry is garbage
        collected does not raise."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        sess = Session()

        sess.add_all([User(name="u1"), User(name="u2"), User(name="u3")])
        sess.commit()

        # TODO: what are we testing here ?   that iteritems() can
        # withstand a change?  should this be
        # more directly attempting to manipulate the identity_map ?
        u1, u2, u3 = sess.query(User).all()
        for i, (key, value) in enumerate(iter(sess.identity_map.items())):
            if i == 2:
                del u3
                gc_collect()

    def _test_extra_dirty_state(self):
        """Set up an after_update hook that sets ``name`` on every object in
        the identity map, then make a1 dirty.

        :return: ``(session, a1, a2)``; committing the session fires the
            hook, which modifies the otherwise clean a2.
        """
        users, User = self.tables.users, self.classes.User
        m = mapper(User, users)

        s = Session()

        @event.listens_for(m, "after_update")
        def e(mapper, conn, target):
            sess = object_session(target)
            for entry in list(sess.identity_map.values()):
                entry.name = "5"

        a1, a2 = User(name="1"), User(name="2")

        s.add_all([a1, a2])
        s.commit()

        a1.name = "3"
        return s, a1, a2

    def test_extra_dirty_state_post_flush_warning(self):
        """Changes made during flush to a previously clean instance emit an
        SAWarning on commit."""
        s, a1, a2 = self._test_extra_dirty_state()
        assert_raises_message(
            sa.exc.SAWarning,
            "Attribute history events accumulated on 1 previously "
            "clean instances",
            s.commit,
        )

    def test_extra_dirty_state_post_flush_state(self):
        """After such a flush, identity_map._modified is empty when
        after_flush_postexec runs."""
        s, a1, a2 = self._test_extra_dirty_state()
        canary = []

        @event.listens_for(s, "after_flush_postexec")
        def e(sess, ctx):
            canary.append(bool(sess.identity_map._modified))

        @testing.emits_warning("Attribute")
        def go():
            s.commit()

        go()
        eq_(canary, [False])

    def test_deleted_auto_expunged(self):
        """A flushed delete removes the object from the session's contents
        while it stays associated; commit fully detaches it."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()
        sess.add(User(name="x"))
        sess.commit()

        u1 = sess.query(User).first()
        sess.delete(u1)

        assert not was_deleted(u1)
        sess.flush()

        assert was_deleted(u1)
        assert u1 not in sess
        assert object_session(u1) is sess
        sess.commit()

        assert object_session(u1) is None

    def test_explicit_expunge_pending(self):
        """Expunging a flushed pending object detaches it, and rollback
        doesn't bring it back."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()
        u1 = User(name="x")
        sess.add(u1)

        sess.flush()
        sess.expunge(u1)

        assert u1 not in sess
        assert object_session(u1) is None

        sess.rollback()

        assert u1 not in sess
        assert object_session(u1) is None

    def test_explicit_expunge_deleted(self):
        """Expunging a flushed-deleted object detaches it; it stays detached
        and was_deleted() across rollback."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()
        sess.add(User(name="x"))
        sess.commit()

        u1 = sess.query(User).first()
        sess.delete(u1)

        sess.flush()

        assert was_deleted(u1)
        assert u1 not in sess
        assert object_session(u1) is sess

        sess.expunge(u1)
        assert was_deleted(u1)
        assert u1 not in sess
        assert object_session(u1) is None

        sess.rollback()
        assert was_deleted(u1)
        assert u1 not in sess
        assert object_session(u1) is None
940
941
942class DeferredRelationshipExpressionTest(_fixtures.FixtureTest):
943    run_inserts = None
944    run_deletes = "each"
945
946    @classmethod
947    def setup_mappers(cls):
948        users, Address, addresses, User = (
949            cls.tables.users,
950            cls.classes.Address,
951            cls.tables.addresses,
952            cls.classes.User,
953        )
954
955        mapper(
956            User,
957            users,
958            properties={"addresses": relationship(Address, backref="user")},
959        )
960        mapper(Address, addresses)
961
962    def test_deferred_expression_unflushed(self):
963        """test that an expression which is dependent on object state is
964        evaluated after the session autoflushes.   This is the lambda
965        inside of strategies.py lazy_clause.
966
967        """
968        User, Address = self.classes("User", "Address")
969
970        sess = create_session(autoflush=True, autocommit=False)
971        u = User(name="ed", addresses=[Address(email_address="foo")])
972        sess.add(u)
973        eq_(
974            sess.query(Address).filter(Address.user == u).one(),
975            Address(email_address="foo"),
976        )
977
    def test_deferred_expression_obj_was_gced(self):
        """A query built against an object still resolves after that
        object has been garbage collected."""
        User, Address = self.classes("User", "Address")

        sess = create_session(autoflush=True, autocommit=False)
        u = User(name="ed", addresses=[Address(email_address="foo")])
        sess.add(u)

        sess.commit()
        sess.close()
        u = sess.query(User).get(u.id)
        q = sess.query(Address).filter(Address.user == u)
        # drop the last local reference so the User can be collected
        # before the query runs
        del u
        gc_collect()
        eq_(q.one(), Address(email_address="foo"))
992
993    def test_deferred_expression_favors_immediate(self):
994        """Test that a deferred expression will return an immediate value
995        if available, rather than invoking after the object is detached
996
997        """
998
999        User, Address = self.classes("User", "Address")
1000
1001        sess = create_session(autoflush=True, autocommit=False)
1002        u = User(name="ed", addresses=[Address(email_address="foo")])
1003        sess.add(u)
1004        sess.commit()
1005
1006        q = sess.query(Address).filter(Address.user == u)
1007        sess.expire(u)
1008        sess.expunge(u)
1009        eq_(q.one(), Address(email_address="foo"))
1010
1011    def test_deferred_expression_obj_was_never_flushed(self):
1012        User, Address = self.classes("User", "Address")
1013
1014        sess = create_session(autoflush=True, autocommit=False)
1015        u = User(name="ed", addresses=[Address(email_address="foo")])
1016
1017        assert_raises_message(
1018            sa.exc.InvalidRequestError,
1019            "Can't resolve value for column users.id on object "
1020            ".User.*.; no value has been set for this column",
1021            (Address.user == u).left.callable,
1022        )
1023
1024        q = sess.query(Address).filter(Address.user == u)
1025        assert_raises_message(
1026            sa.exc.StatementError,
1027            "Can't resolve value for column users.id on object "
1028            ".User.*.; no value has been set for this column",
1029            q.one,
1030        )
1031
1032    def test_deferred_expression_transient_but_manually_set(self):
1033        User, Address = self.classes("User", "Address")
1034
1035        u = User(id=5, name="ed", addresses=[Address(email_address="foo")])
1036
1037        expr = Address.user == u
1038        eq_(expr.left.callable(), 5)
1039
1040    def test_deferred_expression_unflushed_obj_became_detached_unexpired(self):
1041        User, Address = self.classes("User", "Address")
1042
1043        sess = create_session(autoflush=True, autocommit=False)
1044        u = User(name="ed", addresses=[Address(email_address="foo")])
1045
1046        q = sess.query(Address).filter(Address.user == u)
1047
1048        sess.add(u)
1049        sess.flush()
1050
1051        sess.expunge(u)
1052        eq_(q.one(), Address(email_address="foo"))
1053
1054    def test_deferred_expression_unflushed_obj_became_detached_expired(self):
1055        User, Address = self.classes("User", "Address")
1056
1057        sess = create_session(autoflush=True, autocommit=False)
1058        u = User(name="ed", addresses=[Address(email_address="foo")])
1059
1060        q = sess.query(Address).filter(Address.user == u)
1061
1062        sess.add(u)
1063        sess.flush()
1064
1065        sess.expire(u)
1066        sess.expunge(u)
1067        eq_(q.one(), Address(email_address="foo"))
1068
1069    def test_deferred_expr_unflushed_obj_became_detached_expired_by_key(self):
1070        User, Address = self.classes("User", "Address")
1071
1072        sess = create_session(autoflush=True, autocommit=False)
1073        u = User(name="ed", addresses=[Address(email_address="foo")])
1074
1075        q = sess.query(Address).filter(Address.user == u)
1076
1077        sess.add(u)
1078        sess.flush()
1079
1080        sess.expire(u, ["id"])
1081        sess.expunge(u)
1082        eq_(q.one(), Address(email_address="foo"))
1083
    def test_deferred_expression_expired_obj_became_detached_expired(self):
        """If the object's id is already expired when the expression is
        built, and the object is then detached, executing the query raises
        StatementError stating that the value was expired.

        """
        User, Address = self.classes("User", "Address")

        sess = create_session(
            autoflush=True, autocommit=False, expire_on_commit=True
        )
        u = User(name="ed", addresses=[Address(email_address="foo")])

        sess.add(u)
        sess.commit()

        assert "id" not in u.__dict__  # it's expired

        # building the comparison must not refresh the expired 'id';
        # it should not emit SQL
        def go():
            Address.user == u

        self.assert_sql_count(testing.db, go, 0)

        # create the expression here, but note we weren't tracking 'id'
        # yet so we don't have the old value
        q = sess.query(Address).filter(Address.user == u)

        # detach while 'id' is still expired, so it can no longer be loaded
        sess.expunge(u)
        assert_raises_message(
            sa.exc.StatementError,
            "Can't resolve value for column users.id on object "
            ".User.*.; the object is detached and the value was expired",
            q.one,
        )
1114
1115
class SessionStateWFixtureTest(_fixtures.FixtureTest):
    """Session state tests that run against the standard users/addresses
    fixture data."""

    __backend__ = True

    def test_autoflush_rollback(self):
        """Rollback reverts in-session changes to a persistent User (its
        name and addresses collection). A pending Address removed by the
        rollback keeps its unflushed attribute values."""
        Address, addresses, users, User = (
            self.classes.Address,
            self.tables.addresses,
            self.tables.users,
            self.classes.User,
        )

        mapper(Address, addresses)
        mapper(User, users, properties={"addresses": relationship(Address)})

        sess = create_session(autocommit=False, autoflush=True)
        u = sess.query(User).get(8)
        newad = Address(email_address="a new address")
        u.addresses.append(newad)
        u.name = "some new name"
        assert u.name == "some new name"
        assert len(u.addresses) == 4
        assert newad in u.addresses
        sess.rollback()
        assert u.name == "ed"
        assert len(u.addresses) == 3

        assert newad not in u.addresses
        # pending objects don't get expired
        assert newad.email_address == "a new address"

    def test_expunge_cascade(self):
        """Expunging a User with cascade="all" also detaches every loaded
        Address in its ``addresses`` collection."""
        Address, addresses, users, User = (
            self.classes.Address,
            self.tables.addresses,
            self.tables.users,
            self.classes.User,
        )

        mapper(Address, addresses)
        mapper(
            User,
            users,
            properties={
                "addresses": relationship(
                    Address,
                    backref=backref("user", cascade="all"),
                    cascade="all",
                )
            },
        )

        session = create_session()
        u = session.query(User).filter_by(id=7).one()

        # get everything to load in both directions: the addresses
        # collection, then each address's ``user`` backref.  (Previously
        # done through a print() that only served to trigger the loads.)
        for a in u.addresses:
            a.user

        # then see if expunge fails
        session.expunge(u)

        assert sa.orm.object_session(u) is None
        assert sa.orm.attributes.instance_state(u).session_id is None
        for a in u.addresses:
            assert sa.orm.object_session(a) is None
            assert sa.orm.attributes.instance_state(a).session_id is None
1181
1182
class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
    """Test when the ``InstanceState._strong_obj`` link (the cycle this
    class is named for) is present.

    The tests below expect it on pending objects and on modified objects
    attached to a session. They expect it to be absent on transient
    objects, clean persistent objects, expired objects, and objects
    detached via expunge() or close(). One exception: when the owning
    Session is garbage collected rather than closed, the link remains (see
    test_move_gc_session_persistent_dirty).

    """

    run_inserts = None

    def setup(self):
        # per-test setup: map User against the users table
        mapper(self.classes.User, self.tables.users)

    def _assert_modified(self, u1):
        """Assert that u1's InstanceState reports it as modified."""
        assert sa.orm.attributes.instance_state(u1).modified

    def _assert_not_modified(self, u1):
        """Assert that u1's InstanceState reports it as unmodified."""
        assert not sa.orm.attributes.instance_state(u1).modified

    def _assert_cycle(self, u1):
        """Assert that u1's state holds a ``_strong_obj`` reference."""
        assert sa.orm.attributes.instance_state(u1)._strong_obj is not None

    def _assert_no_cycle(self, u1):
        """Assert that u1's state has no ``_strong_obj`` reference."""
        assert sa.orm.attributes.instance_state(u1)._strong_obj is None

    def _persistent_fixture(self):
        """Return ``(session, user)`` with a flushed (persistent, clean)
        User named "ed"."""
        User = self.classes.User
        u1 = User()
        u1.name = "ed"
        sess = Session()
        sess.add(u1)
        sess.flush()
        return sess, u1

    def test_transient(self):
        # a modified transient object has no strong link
        User = self.classes.User
        u1 = User()
        u1.name = "ed"
        self._assert_no_cycle(u1)
        self._assert_modified(u1)

    def test_transient_to_pending(self):
        # the link appears on add() and is dropped once the flush
        # leaves the object clean
        User = self.classes.User
        u1 = User()
        u1.name = "ed"
        self._assert_modified(u1)
        self._assert_no_cycle(u1)
        sess = Session()
        sess.add(u1)
        self._assert_cycle(u1)
        sess.flush()
        self._assert_no_cycle(u1)
        self._assert_not_modified(u1)

    def test_dirty_persistent_to_detached_via_expunge(self):
        sess, u1 = self._persistent_fixture()
        u1.name = "edchanged"
        self._assert_cycle(u1)
        sess.expunge(u1)
        self._assert_no_cycle(u1)

    def test_dirty_persistent_to_detached_via_close(self):
        sess, u1 = self._persistent_fixture()
        u1.name = "edchanged"
        self._assert_cycle(u1)
        sess.close()
        self._assert_no_cycle(u1)

    def test_clean_persistent_to_detached_via_close(self):
        # modifying an object after it has been detached does not
        # create the link
        sess, u1 = self._persistent_fixture()
        self._assert_no_cycle(u1)
        self._assert_not_modified(u1)
        sess.close()
        u1.name = "edchanged"
        self._assert_modified(u1)
        self._assert_no_cycle(u1)

    def test_detached_to_dirty_deleted(self):
        sess, u1 = self._persistent_fixture()
        sess.expunge(u1)
        u1.name = "edchanged"
        self._assert_no_cycle(u1)
        sess.delete(u1)
        self._assert_cycle(u1)

    def test_detached_to_dirty_persistent(self):
        # re-attaching a modified detached object restores the link
        sess, u1 = self._persistent_fixture()
        sess.expunge(u1)
        u1.name = "edchanged"
        self._assert_modified(u1)
        self._assert_no_cycle(u1)
        sess.add(u1)
        self._assert_cycle(u1)
        self._assert_modified(u1)

    def test_detached_to_clean_persistent(self):
        # re-attaching an unmodified detached object does not
        sess, u1 = self._persistent_fixture()
        sess.expunge(u1)
        self._assert_no_cycle(u1)
        self._assert_not_modified(u1)
        sess.add(u1)
        self._assert_no_cycle(u1)
        self._assert_not_modified(u1)

    def test_move_persistent_clean(self):
        sess, u1 = self._persistent_fixture()
        sess.close()
        s2 = Session()
        s2.add(u1)
        self._assert_no_cycle(u1)
        self._assert_not_modified(u1)

    def test_move_persistent_dirty(self):
        sess, u1 = self._persistent_fixture()
        u1.name = "edchanged"
        self._assert_cycle(u1)
        self._assert_modified(u1)
        sess.close()
        self._assert_no_cycle(u1)
        s2 = Session()
        s2.add(u1)
        self._assert_cycle(u1)
        self._assert_modified(u1)

    @testing.requires.predictable_gc
    def test_move_gc_session_persistent_dirty(self):
        # when the owning session is garbage collected instead of closed,
        # the link is not removed
        sess, u1 = self._persistent_fixture()
        u1.name = "edchanged"
        self._assert_cycle(u1)
        self._assert_modified(u1)
        del sess
        gc_collect()
        self._assert_cycle(u1)
        s2 = Session()
        s2.add(u1)
        self._assert_cycle(u1)
        self._assert_modified(u1)

    def test_persistent_dirty_to_expired(self):
        # expiring discards the pending change and the link with it
        sess, u1 = self._persistent_fixture()
        u1.name = "edchanged"
        self._assert_cycle(u1)
        self._assert_modified(u1)
        sess.expire(u1)
        self._assert_no_cycle(u1)
        self._assert_not_modified(u1)
1328
1329
class WeakIdentityMapTest(_fixtures.FixtureTest):
    """Tests for the weak-referencing identity map: unmodified objects drop
    out once unreferenced, and modified ones are retained until flushed."""

    run_inserts = None

    @testing.requires.predictable_gc
    def test_weakref(self):
        """test the weak-referencing identity map, which strongly-
        references modified items."""

        users, User = self.tables.users, self.classes.User

        s = create_session()
        mapper(User, users)

        s.add(User(name="ed"))
        s.flush()
        assert not s.dirty

        # an unmodified object leaves the map once unreferenced
        user = s.query(User).one()
        del user
        gc_collect()
        assert len(s.identity_map) == 0

        # a modified object stays in the map (and in dirty) until flushed
        user = s.query(User).one()
        user.name = "fred"
        del user
        gc_collect()
        assert len(s.identity_map) == 1
        assert len(s.dirty) == 1
        assert None not in s.dirty
        s.flush()
        gc_collect()
        assert not s.dirty
        assert not s.identity_map

        user = s.query(User).one()
        assert user.name == "fred"
        assert s.identity_map

    @testing.requires.predictable_gc
    def test_weakref_pickled(self):
        """A modified object that went through a pickle round trip and was
        re-added is retained in the map until it is flushed."""
        users, User = self.tables.users, pickleable.User

        s = create_session()
        mapper(User, users)

        s.add(User(name="ed"))
        s.flush()
        assert not s.dirty

        user = s.query(User).one()
        user.name = "fred"
        s.expunge(user)

        # round-trip the detached, modified object through pickle
        u2 = pickle.loads(pickle.dumps(user))

        del user
        s.add(u2)

        del u2
        gc_collect()

        assert len(s.identity_map) == 1
        assert len(s.dirty) == 1
        assert None not in s.dirty
        s.flush()
        gc_collect()
        assert not s.dirty

        assert not s.identity_map

    @testing.requires.predictable_gc
    def test_weakref_with_cycles_o2m(self):
        """Weak-map behavior with a one-to-many User/Address pair whose
        backref has been loaded."""
        Address, addresses, users, User = (
            self.classes.Address,
            self.tables.addresses,
            self.tables.users,
            self.classes.User,
        )

        s = sessionmaker()()
        mapper(
            User,
            users,
            properties={"addresses": relationship(Address, backref="user")},
        )
        mapper(Address, addresses)
        s.add(User(name="ed", addresses=[Address(email_address="ed1")]))
        s.commit()

        user = s.query(User).options(joinedload(User.addresses)).one()
        user.addresses[0].user  # lazyload
        eq_(user, User(name="ed", addresses=[Address(email_address="ed1")]))

        # unmodified: both objects leave the map despite the references
        # between them
        del user
        gc_collect()
        assert len(s.identity_map) == 0

        # a modified Address keeps the identity map at two entries
        user = s.query(User).options(joinedload(User.addresses)).one()
        user.addresses[0].email_address = "ed2"
        user.addresses[0].user  # lazyload
        del user
        gc_collect()
        assert len(s.identity_map) == 2

        s.commit()
        user = s.query(User).options(joinedload(User.addresses)).one()
        eq_(user, User(name="ed", addresses=[Address(email_address="ed2")]))

    @testing.requires.predictable_gc
    def test_weakref_with_cycles_o2o(self):
        """Same as the o2m variant, using a scalar (uselist=False)
        relationship."""
        Address, addresses, users, User = (
            self.classes.Address,
            self.tables.addresses,
            self.tables.users,
            self.classes.User,
        )

        s = sessionmaker()()
        mapper(
            User,
            users,
            properties={
                "address": relationship(Address, backref="user", uselist=False)
            },
        )
        mapper(Address, addresses)
        s.add(User(name="ed", address=Address(email_address="ed1")))
        s.commit()

        user = s.query(User).options(joinedload(User.address)).one()
        user.address.user  # load the backref
        eq_(user, User(name="ed", address=Address(email_address="ed1")))

        del user
        gc_collect()
        assert len(s.identity_map) == 0

        user = s.query(User).options(joinedload(User.address)).one()
        user.address.email_address = "ed2"
        user.address.user  # lazyload

        del user
        gc_collect()
        assert len(s.identity_map) == 2

        s.commit()
        user = s.query(User).options(joinedload(User.address)).one()
        eq_(user, User(name="ed", address=Address(email_address="ed2")))

    def test_auto_detach_on_gc_session(self):
        """Once its owning Session has been garbage collected, an object
        can be added to another Session."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        sess = Session()

        u1 = User(name="u1")
        sess.add(u1)
        sess.commit()

        # can't add u1 to s2,
        # it already belongs to sess
        s2 = Session()
        assert_raises_message(
            sa.exc.InvalidRequestError,
            r".*is already attached to session",
            s2.add,
            u1,
        )

        # garbage collect sess
        del sess
        gc_collect()

        # s2 lets it in now despite u1 having
        # session_key
        s2.add(u1)
        assert u1 in s2

    def test_fast_discard_race(self):
        # test issue #4068
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        sess = Session()

        u1 = User(name="u1")
        sess.add(u1)
        sess.commit()

        # make u1's state look as though its object was collected, while
        # keeping the original weakref so its cleanup can be run later
        u1_state = u1._sa_instance_state
        ref = u1_state.obj
        u1_state.obj = lambda: None

        # load the same row again, then run the stale state's cleanup late;
        # the test expects the newly loaded object to stay in the map
        u2 = sess.query(User).first()
        u1_state._cleanup(ref)

        u3 = sess.query(User).first()

        is_(u2, u3)

        # cleanup of the current state does remove it from the map
        u2_state = u2._sa_instance_state
        ref = u2_state.obj
        u2_state.obj = lambda: None
        u2_state._cleanup(ref)
        assert not sess.identity_map.contains_state(u2._sa_instance_state)
1537
1538
class IsModifiedTest(_fixtures.FixtureTest):
    """Tests for ``Session.is_modified()``."""

    run_inserts = None

    def _default_mapping_fixture(self):
        """Map User (with an ``addresses`` relationship) and Address, and
        return the two classes."""
        users, addresses = self.tables.users, self.tables.addresses
        User, Address = self.classes.User, self.classes.Address
        mapper(User, users, properties={"addresses": relationship(Address)})
        mapper(Address, addresses)
        return User, Address

    def test_is_modified(self):
        User, Address = self._default_mapping_fixture()

        session = create_session()

        # persist a user, then start over with an empty session
        fred = User(name="fred")
        session.add(fred)
        session.flush()
        session.expunge_all()

        loaded = session.query(User).one()
        assert loaded not in session.dirty
        assert not session.is_modified(loaded)

        # re-assigning the same value: dirty, but not modified
        loaded.name = "fred"
        assert loaded in session.dirty
        assert not session.is_modified(loaded)

        loaded.name = "ed"
        assert loaded in session.dirty
        assert session.is_modified(loaded)

        session.flush()
        assert loaded not in session.dirty
        assert not session.is_modified(loaded)

        # a collection change only counts when collections are included
        new_address = Address()
        loaded.addresses.append(new_address)
        assert loaded in session.dirty
        assert session.is_modified(loaded)
        assert not session.is_modified(loaded, include_collections=False)

    def test_is_modified_passive_off(self):
        """as of 0.8 no SQL is emitted for is_modified()
        regardless of the passive flag"""

        User, Address = self._default_mapping_fixture()

        session = Session()
        user = User(name="fred", addresses=[Address(email_address="foo")])
        session.add(user)
        session.commit()

        # load column attributes before counting SQL
        user.id

        def check_unmodified():
            assert not session.is_modified(user)

        self.assert_sql_count(testing.db, check_unmodified, 0)

        session.expire_all()
        user.name = "newname"

        # the result can't be predicted deterministically; it depends on
        # whether 'name' or 'addresses' is tested first
        modified = session.is_modified(user)
        addresses_loaded = "addresses" in user.__dict__
        assert modified is not addresses_loaded

    def test_is_modified_syn(self):
        User, users = self.classes.User, self.tables.users

        session = sessionmaker()()

        mapper(User, users, properties={"uname": sa.orm.synonym("name")})
        user = User(uname="fred")
        # a transient object set through a synonym reports as modified
        assert session.is_modified(user)
        session.add(user)
        session.commit()
        assert not session.is_modified(user)
1618
1619
class DisposedStates(fixtures.MappedTest):
    """Session operations tolerate InstanceStates that were disposed and
    discarded from the identity map after ``all_states()`` had already
    returned them."""

    run_setup_mappers = "once"
    run_inserts = "once"
    run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "t1",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )

    @classmethod
    def setup_classes(cls):
        class T(cls.Basic):
            def __init__(self, data):
                self.data = data

        mapper(T, cls.tables.t1)

    def teardown(self):
        # empty sqlalchemy.orm.session._sessions before the normal teardown
        from sqlalchemy.orm.session import _sessions

        _sessions.clear()
        super(DisposedStates, self).teardown()

    def _set_imap_in_disposal(self, sess, *objs):
        """Remove selected objects from the given session, as though
        they were dereferenced and removed from the WeakIdentityMap.

        Hardcodes the identity map's ``all_states()`` method to return the
        full list of states captured beforehand.  This simulates
        ``all_states()`` returning results, after which some of the states
        get garbage collected (this normally only happens during
        asynchronous gc).  The Session then holds one or more InstanceStates
        that have been removed from the identity map and disposed.

        The tests check that Session operations don't trip over this.

        """
        identity_map = sess.identity_map
        snapshot = identity_map.all_states()
        identity_map.all_states = lambda: snapshot
        for target in objs:
            target_state = attributes.instance_state(target)
            identity_map.discard(target_state)
            target_state._dispose()

    def _test_session(self, **kwargs):
        """Create a session with five flushed T objects. o1 and o5 are then
        modified, and o2, o4 and o5 are disposed."""
        T = self.classes.T
        sess = create_session(**kwargs)

        objs = [T(value) for value in ("t1", "t2", "t3", "t4", "t5")]
        o1, o2, o3, o4, o5 = objs

        sess.add_all(objs)

        sess.flush()

        # dirty one surviving object and one object that will be disposed
        o1.data = "t1modified"
        o5.data = "t5modified"

        self._set_imap_in_disposal(sess, o2, o4, o5)
        return sess

    def test_flush(self):
        self._test_session().flush()

    def test_clear(self):
        self._test_session().expunge_all()

    def test_close(self):
        self._test_session().close()

    def test_invalidate(self):
        self._test_session().invalidate()

    def test_expunge_all(self):
        self._test_session().expunge_all()

    def test_expire_all(self):
        self._test_session().expire_all()

    def test_rollback(self):
        sess = self._test_session(autocommit=False, expire_on_commit=True)
        sess.commit()

        sess.rollback()
1717
1718
class SessionInterface(fixtures.TestBase):
    """Bogus args to Session methods produce actionable exceptions."""

    # TODO: expand with message body assertions.

    _class_methods = {"connection", "execute", "get_bind", "scalar"}

    def _public_session_methods(self):
        """Return the names in ``Session.public_methods`` (except ``begin``
        and ``query``) whose signature has more than one positional
        parameter or accepts ``*args``."""
        session_cls = sa.orm.session.Session
        skipped = {"begin", "query"}

        def _takes_args(name):
            spec = inspect_getfullargspec(getattr(session_cls, name))
            positional, varargs = spec[0], spec[1]
            return len(positional) > 1 or bool(varargs)

        return {
            name
            for name in session_cls.public_methods
            if name not in skipped and _takes_args(name)
        }

    def _map_it(self, cls):
        """Map ``cls`` to a fresh single-column table named "t"."""
        id_column = Column(
            "id", Integer, primary_key=True, test_needs_autoincrement=True
        )
        table = Table("t", sa.MetaData(), id_column)
        return mapper(cls, table)

    def _test_instance_guards(self, user_arg):
        """Assert instance-taking Session methods raise
        UnmappedInstanceError for ``user_arg``. Also check that the set of
        methods exercised matches the public ones that take arguments,
        excluding the class-level and bulk methods."""
        watchdog = set()

        def x_raises_(obj, method, *args, **kw):
            watchdog.add(method)
            callable_ = getattr(obj, method)
            assert_raises(
                sa.orm.exc.UnmappedInstanceError, callable_, *args, **kw
            )

        def raises_(method, *args, **kw):
            x_raises_(create_session(), method, *args, **kw)

        for method, args in (
            ("__contains__", (user_arg,)),
            ("add", (user_arg,)),
            ("add_all", ((user_arg,),)),
            ("delete", (user_arg,)),
            ("expire", (user_arg,)),
            ("expunge", (user_arg,)),
        ):
            raises_(method, *args)

        # flush will no-op without something in the unit of work
        def _flush_with_pending_object():
            class OK(object):
                pass

            self._map_it(OK)

            s = create_session()
            s.add(OK())
            x_raises_(s, "flush", (user_arg,))

        _flush_with_pending_object()

        for method in ("is_modified", "merge", "refresh"):
            raises_(method, user_arg)

        bulk_methods = {
            "bulk_update_mappings",
            "bulk_insert_mappings",
            "bulk_save_objects",
        }
        expected = self._public_session_methods() - self._class_methods
        expected -= bulk_methods

        eq_(watchdog, expected, watchdog.symmetric_difference(expected))

    def _test_class_guards(self, user_arg, is_class=True):
        """Assert each method in ``_class_methods`` rejects ``user_arg``
        given as ``mapper=``. UnmappedClassError is expected when
        ``is_class`` is true, NoInspectionAvailable otherwise."""
        watchdog = set()
        if is_class:
            expected_exc = sa.orm.exc.UnmappedClassError
        else:
            expected_exc = sa.exc.NoInspectionAvailable

        def raises_(method, *args, **kw):
            watchdog.add(method)
            callable_ = getattr(create_session(), method)
            assert_raises(expected_exc, callable_, *args, **kw)

        for method, args in (
            ("connection", ()),
            ("execute", ("SELECT 1",)),
            ("get_bind", ()),
            ("scalar", ("SELECT 1",)),
        ):
            raises_(method, *args, mapper=user_arg)

        eq_(
            watchdog,
            self._class_methods,
            watchdog.symmetric_difference(self._class_methods),
        )

    def test_unmapped_instance(self):
        class Unmapped(object):
            pass

        self._test_instance_guards(Unmapped())
        self._test_class_guards(Unmapped)

    def test_unmapped_primitives(self):
        for prim in ("doh", 123, ("t", "u", "p", "l", "e")):
            self._test_instance_guards(prim)
            self._test_class_guards(prim, is_class=False)

    def test_unmapped_class_for_instance(self):
        class Unmapped(object):
            pass

        self._test_instance_guards(Unmapped)
        self._test_class_guards(Unmapped)

    def test_mapped_class_for_instance(self):
        class Mapped(object):
            pass

        self._map_it(Mapped)

        self._test_instance_guards(Mapped)
        # no class guards- it would pass.

    def test_missing_state(self):
        class Mapped(object):
            pass

        # instance created before its class was mapped
        early = Mapped()
        self._map_it(Mapped)

        self._test_instance_guards(early)
        self._test_class_guards(early, is_class=False)

    def test_refresh_arg_signature(self):
        from sqlalchemy.orm.query import LockmodeArg

        class Mapped(object):
            pass

        self._map_it(Mapped)

        instance = Mapped()
        session = create_session()

        with mock.patch.object(session, "_validate_persistent"):
            assert_raises_message(
                sa.exc.ArgumentError,
                "with_for_update should be the boolean value True, "
                "or a dictionary with options",
                session.refresh,
                instance,
                with_for_update={},
            )

            with mock.patch(
                "sqlalchemy.orm.session.loading.load_on_ident"
            ) as load_on_ident:
                for lock_arg in ({"read": True}, True, False):
                    session.refresh(instance, with_for_update=lock_arg)
                session.refresh(instance)

            # keyword argument received by each load_on_ident call
            passed = [
                mock_call[-1]["with_for_update"]
                for mock_call in load_on_ident.mock_calls
            ]
            eq_(passed, [LockmodeArg(read=True), LockmodeArg(), None, None])
1920
1921
class FlushWarningsTest(fixtures.MappedTest):
    """Session and relationship operations performed inside a flush-time
    ``after_insert`` mapper event raise an SAWarning naming the
    operation."""

    run_setup_mappers = "each"

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "user",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(20)),
        )

        Table(
            "address",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("user_id", Integer, ForeignKey("user.id")),
            Column("email", String(20)),
        )

    @classmethod
    def setup_classes(cls):
        class User(cls.Basic):
            pass

        class Address(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        user, User = cls.tables.user, cls.classes.User
        address, Address = cls.tables.address, cls.classes.Address
        mapper(
            User,
            user,
            properties={"addresses": relationship(Address, backref="user")},
        )
        mapper(Address, address)

    def test_o2m_cascade_add(self):
        Address = self.classes.Address

        def evt(mapper, conn, instance):
            instance.addresses.append(Address(email="x1"))

        self._test(evt, "collection append")

    def test_o2m_cascade_remove(self):
        def evt(mapper, conn, instance):
            del instance.addresses[0]

        self._test(evt, "collection remove")

    def test_m2o_cascade_add(self):
        User = self.classes.User

        def evt(mapper, conn, instance):
            instance.addresses[0].user = User(name="u2")

        self._test(evt, "related attribute set")

    def test_m2o_cascade_remove(self):
        def evt(mapper, conn, instance):
            a1 = instance.addresses[0]
            del a1.user

        self._test(evt, "related attribute delete")

    def test_plain_add(self):
        Address = self.classes.Address

        def evt(mapper, conn, instance):
            object_session(instance).add(Address(email="x1"))

        self._test(evt, r"Session.add\(\)")

    def test_plain_merge(self):
        Address = self.classes.Address

        def evt(mapper, conn, instance):
            object_session(instance).merge(Address(email="x1"))

        self._test(evt, r"Session.merge\(\)")

    def test_plain_delete(self):
        Address = self.classes.Address

        def evt(mapper, conn, instance):
            object_session(instance).delete(Address(email="x1"))

        self._test(evt, r"Session.delete\(\)")

    def _test(self, fn, method):
        """Install ``fn`` as a User ``after_insert`` listener. Then commit a
        new User that has one Address, and assert the commit raises
        SAWarning matching ``"Usage of the '<method>'"``.

        :param fn: listener ``(mapper, conn, instance)`` that performs the
            operation under test.
        :param method: regular-expression fragment expected in the warning
            message after "Usage of the '".
        """
        User = self.classes.User
        Address = self.classes.Address

        s = Session()
        event.listen(User, "after_insert", fn)

        # the address table's column is ``email``; there is no ``name``
        # column, so a ``name=`` keyword would only set an unmapped attribute
        u1 = User(name="u1", addresses=[Address(email="a1")])
        s.add(u1)
        try:
            assert_raises_message(
                sa.exc.SAWarning, "Usage of the '%s'" % method, s.commit
            )
        finally:
            # don't leave the session open in whatever state the failed
            # commit left it
            s.close()
2030