1import sqlalchemy as sa
2from sqlalchemy import event
3from sqlalchemy import ForeignKey
4from sqlalchemy import Integer
5from sqlalchemy import String
6from sqlalchemy import testing
7from sqlalchemy.ext.declarative import declarative_base
8from sqlalchemy.orm import attributes
9from sqlalchemy.orm import class_mapper
10from sqlalchemy.orm import configure_mappers
11from sqlalchemy.orm import create_session
12from sqlalchemy.orm import deferred
13from sqlalchemy.orm import events
14from sqlalchemy.orm import EXT_SKIP
15from sqlalchemy.orm import instrumentation
16from sqlalchemy.orm import Mapper
17from sqlalchemy.orm import mapper
18from sqlalchemy.orm import query
19from sqlalchemy.orm import relationship
20from sqlalchemy.orm import Session
21from sqlalchemy.orm import sessionmaker
22from sqlalchemy.orm.mapper import _mapper_registry
23from sqlalchemy.testing import assert_raises
24from sqlalchemy.testing import assert_raises_message
25from sqlalchemy.testing import AssertsCompiledSQL
26from sqlalchemy.testing import eq_
27from sqlalchemy.testing import expect_warnings
28from sqlalchemy.testing import fixtures
29from sqlalchemy.testing import is_not
30from sqlalchemy.testing.assertsql import CompiledSQL
31from sqlalchemy.testing.mock import ANY
32from sqlalchemy.testing.mock import call
33from sqlalchemy.testing.mock import Mock
34from sqlalchemy.testing.schema import Column
35from sqlalchemy.testing.schema import Table
36from sqlalchemy.testing.util import gc_collect
37from test.orm import _fixtures
38
39
class _RemoveListeners(object):
    """Test mixin whose teardown clears the ORM event classes."""

    def teardown(self):
        """Call ``_clear()`` on each ORM event class, then chain upward."""
        for event_cls in (
            events.MapperEvents,
            events.InstanceEvents,
            events.SessionEvents,
            events.InstrumentationEvents,
            events.QueryEvents,
        ):
            event_cls._clear()
        super(_RemoveListeners, self).teardown()
48
49
50class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
51    run_inserts = None
52
53    @classmethod
54    def define_tables(cls, metadata):
55        super(MapperEventsTest, cls).define_tables(metadata)
56        metadata.tables["users"].append_column(
57            Column("extra", Integer, default=5, onupdate=10)
58        )
59
60    def test_instance_event_listen(self):
61        """test listen targets for instance events"""
62
63        users, addresses = self.tables.users, self.tables.addresses
64
65        canary = []
66
67        class A(object):
68            pass
69
70        class B(A):
71            pass
72
73        mapper(A, users)
74        mapper(
75            B, addresses, inherits=A, properties={"address_id": addresses.c.id}
76        )
77
78        def init_a(target, args, kwargs):
79            canary.append(("init_a", target))
80
81        def init_b(target, args, kwargs):
82            canary.append(("init_b", target))
83
84        def init_c(target, args, kwargs):
85            canary.append(("init_c", target))
86
87        def init_d(target, args, kwargs):
88            canary.append(("init_d", target))
89
90        def init_e(target, args, kwargs):
91            canary.append(("init_e", target))
92
93        event.listen(mapper, "init", init_a)
94        event.listen(Mapper, "init", init_b)
95        event.listen(class_mapper(A), "init", init_c)
96        event.listen(A, "init", init_d)
97        event.listen(A, "init", init_e, propagate=True)
98
99        a = A()
100        eq_(
101            canary,
102            [
103                ("init_a", a),
104                ("init_b", a),
105                ("init_c", a),
106                ("init_d", a),
107                ("init_e", a),
108            ],
109        )
110
111        # test propagate flag
112        canary[:] = []
113        b = B()
114        eq_(canary, [("init_a", b), ("init_b", b), ("init_e", b)])
115
116    def listen_all(self, mapper, **kw):
117        canary = []
118
119        def evt(meth):
120            def go(*args, **kwargs):
121                canary.append(meth)
122
123            return go
124
125        for meth in [
126            "init",
127            "init_failure",
128            "load",
129            "refresh",
130            "refresh_flush",
131            "expire",
132            "before_insert",
133            "after_insert",
134            "before_update",
135            "after_update",
136            "before_delete",
137            "after_delete",
138        ]:
139            event.listen(mapper, meth, evt(meth), **kw)
140        return canary
141
142    def test_init_allow_kw_modify(self):
143        User, users = self.classes.User, self.tables.users
144        mapper(User, users)
145
146        @event.listens_for(User, "init")
147        def add_name(obj, args, kwargs):
148            kwargs["name"] = "ed"
149
150        u1 = User()
151        eq_(u1.name, "ed")
152
153    def test_init_failure_hook(self):
154        users = self.tables.users
155
156        class Thing(object):
157            def __init__(self, **kw):
158                if kw.get("fail"):
159                    raise Exception("failure")
160
161        mapper(Thing, users)
162
163        canary = Mock()
164        event.listen(Thing, "init_failure", canary)
165
166        Thing()
167        eq_(canary.mock_calls, [])
168
169        assert_raises_message(Exception, "failure", Thing, fail=True)
170        eq_(canary.mock_calls, [call(ANY, (), {"fail": True})])
171
172    def test_listen_doesnt_force_compile(self):
173        User, users = self.classes.User, self.tables.users
174        m = mapper(
175            User,
176            users,
177            properties={
178                # intentionally non-existent class to ensure
179                # the lambda is not called, simulates a class from
180                # a not-yet-imported module
181                "addresses": relationship(lambda: ImNotAClass)  # noqa
182            },
183        )
184        event.listen(User, "before_insert", lambda *a, **kw: None)
185        assert not m.configured
186
187    def test_basic(self):
188        User, users = self.classes.User, self.tables.users
189
190        mapper(User, users)
191        canary = self.listen_all(User)
192        named_canary = self.listen_all(User, named=True)
193
194        sess = create_session()
195        u = User(name="u1")
196        sess.add(u)
197        sess.flush()
198        sess.expire(u)
199        u = sess.query(User).get(u.id)
200        sess.expunge_all()
201        u = sess.query(User).get(u.id)
202        u.name = "u1 changed"
203        sess.flush()
204        sess.delete(u)
205        sess.flush()
206        expected = [
207            "init",
208            "before_insert",
209            "refresh_flush",
210            "after_insert",
211            "expire",
212            "refresh",
213            "load",
214            "before_update",
215            "refresh_flush",
216            "after_update",
217            "before_delete",
218            "after_delete",
219        ]
220        eq_(canary, expected)
221        eq_(named_canary, expected)
222
223    def test_insert_before_configured(self):
224        users, User = self.tables.users, self.classes.User
225
226        mapper(User, users)
227
228        canary = Mock()
229
230        event.listen(mapper, "before_configured", canary.listen1)
231        event.listen(mapper, "before_configured", canary.listen2, insert=True)
232        event.listen(mapper, "before_configured", canary.listen3)
233        event.listen(mapper, "before_configured", canary.listen4, insert=True)
234
235        configure_mappers()
236
237        eq_(
238            canary.mock_calls,
239            [call.listen4(), call.listen2(), call.listen1(), call.listen3()],
240        )
241
242    def test_insert_flags(self):
243        users, User = self.tables.users, self.classes.User
244
245        m = mapper(User, users)
246
247        canary = Mock()
248
249        arg = Mock()
250
251        event.listen(m, "before_insert", canary.listen1)
252        event.listen(m, "before_insert", canary.listen2, insert=True)
253        event.listen(
254            m, "before_insert", canary.listen3, propagate=True, insert=True
255        )
256        event.listen(m, "load", canary.listen4)
257        event.listen(m, "load", canary.listen5, insert=True)
258        event.listen(m, "load", canary.listen6, propagate=True, insert=True)
259
260        User()
261        m.dispatch.before_insert(arg, arg, arg)
262        m.class_manager.dispatch.load(arg, arg)
263        eq_(
264            canary.mock_calls,
265            [
266                call.listen3(arg, arg, arg.obj()),
267                call.listen2(arg, arg, arg.obj()),
268                call.listen1(arg, arg, arg.obj()),
269                call.listen6(arg.obj(), arg),
270                call.listen5(arg.obj(), arg),
271                call.listen4(arg.obj(), arg),
272            ],
273        )
274
275    def test_merge(self):
276        users, User = self.tables.users, self.classes.User
277
278        mapper(User, users)
279
280        canary = []
281
282        def load(obj, ctx):
283            canary.append("load")
284
285        event.listen(mapper, "load", load)
286
287        s = Session()
288        u = User(name="u1")
289        s.add(u)
290        s.commit()
291        s = Session()
292        u2 = s.merge(u)
293        s = Session()
294        u2 = s.merge(User(name="u2"))  # noqa
295        s.commit()
296        s.query(User).order_by(User.id).first()
297        eq_(canary, ["load", "load", "load"])
298
299    def test_inheritance(self):
300        users, addresses, User = (
301            self.tables.users,
302            self.tables.addresses,
303            self.classes.User,
304        )
305
306        class AdminUser(User):
307            pass
308
309        mapper(User, users)
310        mapper(
311            AdminUser,
312            addresses,
313            inherits=User,
314            properties={"address_id": addresses.c.id},
315        )
316
317        canary1 = self.listen_all(User, propagate=True)
318        canary2 = self.listen_all(User)
319        canary3 = self.listen_all(AdminUser)
320
321        sess = create_session()
322        am = AdminUser(name="au1", email_address="au1@e1")
323        sess.add(am)
324        sess.flush()
325        am = sess.query(AdminUser).populate_existing().get(am.id)
326        sess.expunge_all()
327        am = sess.query(AdminUser).get(am.id)
328        am.name = "au1 changed"
329        sess.flush()
330        sess.delete(am)
331        sess.flush()
332        eq_(
333            canary1,
334            [
335                "init",
336                "before_insert",
337                "refresh_flush",
338                "after_insert",
339                "refresh",
340                "load",
341                "before_update",
342                "refresh_flush",
343                "after_update",
344                "before_delete",
345                "after_delete",
346            ],
347        )
348        eq_(canary2, [])
349        eq_(
350            canary3,
351            [
352                "init",
353                "before_insert",
354                "refresh_flush",
355                "after_insert",
356                "refresh",
357                "load",
358                "before_update",
359                "refresh_flush",
360                "after_update",
361                "before_delete",
362                "after_delete",
363            ],
364        )
365
366    def test_inheritance_subclass_deferred(self):
367        users, addresses, User = (
368            self.tables.users,
369            self.tables.addresses,
370            self.classes.User,
371        )
372
373        mapper(User, users)
374
375        canary1 = self.listen_all(User, propagate=True)
376        canary2 = self.listen_all(User)
377
378        class AdminUser(User):
379            pass
380
381        mapper(
382            AdminUser,
383            addresses,
384            inherits=User,
385            properties={"address_id": addresses.c.id},
386        )
387        canary3 = self.listen_all(AdminUser)
388
389        sess = create_session()
390        am = AdminUser(name="au1", email_address="au1@e1")
391        sess.add(am)
392        sess.flush()
393        am = sess.query(AdminUser).populate_existing().get(am.id)
394        sess.expunge_all()
395        am = sess.query(AdminUser).get(am.id)
396        am.name = "au1 changed"
397        sess.flush()
398        sess.delete(am)
399        sess.flush()
400        eq_(
401            canary1,
402            [
403                "init",
404                "before_insert",
405                "refresh_flush",
406                "after_insert",
407                "refresh",
408                "load",
409                "before_update",
410                "refresh_flush",
411                "after_update",
412                "before_delete",
413                "after_delete",
414            ],
415        )
416        eq_(canary2, [])
417        eq_(
418            canary3,
419            [
420                "init",
421                "before_insert",
422                "refresh_flush",
423                "after_insert",
424                "refresh",
425                "load",
426                "before_update",
427                "refresh_flush",
428                "after_update",
429                "before_delete",
430                "after_delete",
431            ],
432        )
433
434    def test_before_after_only_collection(self):
435        """before_update is called on parent for collection modifications,
436        after_update is called even if no columns were updated.
437
438        """
439
440        keywords, items, item_keywords, Keyword, Item = (
441            self.tables.keywords,
442            self.tables.items,
443            self.tables.item_keywords,
444            self.classes.Keyword,
445            self.classes.Item,
446        )
447
448        mapper(
449            Item,
450            items,
451            properties={
452                "keywords": relationship(Keyword, secondary=item_keywords)
453            },
454        )
455        mapper(Keyword, keywords)
456
457        canary1 = self.listen_all(Item)
458        canary2 = self.listen_all(Keyword)
459
460        sess = create_session()
461        i1 = Item(description="i1")
462        k1 = Keyword(name="k1")
463        sess.add(i1)
464        sess.add(k1)
465        sess.flush()
466        eq_(canary1, ["init", "before_insert", "after_insert"])
467        eq_(canary2, ["init", "before_insert", "after_insert"])
468
469        canary1[:] = []
470        canary2[:] = []
471
472        i1.keywords.append(k1)
473        sess.flush()
474        eq_(canary1, ["before_update", "after_update"])
475        eq_(canary2, [])
476
477    def test_before_after_configured_warn_on_non_mapper(self):
478        User, users = self.classes.User, self.tables.users
479
480        m1 = Mock()
481
482        mapper(User, users)
483        assert_raises_message(
484            sa.exc.SAWarning,
485            r"before_configured' and 'after_configured' ORM events only "
486            r"invoke with the mapper\(\) function or Mapper class as "
487            r"the target.",
488            event.listen,
489            User,
490            "before_configured",
491            m1,
492        )
493
494        assert_raises_message(
495            sa.exc.SAWarning,
496            r"before_configured' and 'after_configured' ORM events only "
497            r"invoke with the mapper\(\) function or Mapper class as "
498            r"the target.",
499            event.listen,
500            User,
501            "after_configured",
502            m1,
503        )
504
505    def test_before_after_configured(self):
506        User, users = self.classes.User, self.tables.users
507
508        m1 = Mock()
509        m2 = Mock()
510
511        mapper(User, users)
512
513        event.listen(mapper, "before_configured", m1)
514        event.listen(mapper, "after_configured", m2)
515
516        s = Session()
517        s.query(User)
518
519        eq_(m1.mock_calls, [call()])
520        eq_(m2.mock_calls, [call()])
521
522    def test_instrument_event(self):
523        Address, addresses, users, User = (
524            self.classes.Address,
525            self.tables.addresses,
526            self.tables.users,
527            self.classes.User,
528        )
529
530        canary = []
531
532        def instrument_class(mapper, cls):
533            canary.append(cls)
534
535        event.listen(Mapper, "instrument_class", instrument_class)
536
537        mapper(User, users)
538        eq_(canary, [User])
539        mapper(Address, addresses)
540        eq_(canary, [User, Address])
541
542    def test_instrument_class_precedes_class_instrumentation(self):
543        users = self.tables.users
544
545        class MyClass(object):
546            pass
547
548        canary = Mock()
549
550        def my_init(self):
551            canary.init()
552
553        # mapper level event
554        @event.listens_for(mapper, "instrument_class")
555        def instrument_class(mp, class_):
556            canary.instrument_class(class_)
557            class_.__init__ = my_init
558
559        # instrumentationmanager event
560        @event.listens_for(object, "class_instrument")
561        def class_instrument(class_):
562            canary.class_instrument(class_)
563
564        mapper(MyClass, users)
565
566        m1 = MyClass()
567        assert attributes.instance_state(m1)
568
569        eq_(
570            [
571                call.instrument_class(MyClass),
572                call.class_instrument(MyClass),
573                call.init(),
574            ],
575            canary.mock_calls,
576        )
577
578    def test_before_mapper_configured_event(self):
579        """Test [ticket:4397].
580
581        This event is intended to allow a specific mapper to be skipped during
582        the configure step, by returning a value of
583        :attr:`.orm.interfaces.EXT_SKIP` which means the mapper will be skipped
584        within this configure run.    The "new mappers" flag will remain set in
585        this case and the configure operation will occur again.
586
587        This event, and its return value, make it possible to query one base
588        while a different one still needs configuration, which cannot be
589        completed at this time.
590        """
591
592        User, users = self.classes.User, self.tables.users
593        mapper(User, users)
594
595        AnotherBase = declarative_base()
596
597        class Animal(AnotherBase):
598            __tablename__ = "animal"
599            species = Column(String(30), primary_key=True)
600            __mapper_args__ = dict(
601                polymorphic_on="species", polymorphic_identity="Animal"
602            )
603
604        # Register the first classes and create their Mappers:
605        configure_mappers()
606
607        unconfigured = [m for m in _mapper_registry if not m.configured]
608        eq_(0, len(unconfigured))
609
610        # Declare a subclass, table and mapper, which refers to one that has
611        # not been loaded yet (Employer), and therefore cannot be configured:
612        class Mammal(Animal):
613            nonexistent = relationship("Nonexistent")
614
615        # These new classes should not be configured at this point:
616        unconfigured = [m for m in _mapper_registry if not m.configured]
617        eq_(1, len(unconfigured))
618
619        # Now try to query User, which is internally consistent. This query
620        # fails by default because Mammal needs to be configured, and cannot
621        # be:
622        def probe():
623            s = Session()
624            s.query(User)
625
626        assert_raises(sa.exc.InvalidRequestError, probe)
627
628        # If we disable configuring mappers while querying, then it succeeds:
629        @event.listens_for(
630            AnotherBase,
631            "before_mapper_configured",
632            propagate=True,
633            retval=True,
634        )
635        def disable_configure_mappers(mapper, cls):
636            return EXT_SKIP
637
638        probe()
639
640
641class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
    @classmethod
    def setup_classes(cls):
        """Declare the mapped classes ``A`` and ``B`` used by these tests."""

        class A(cls.DeclarativeBasic):
            __tablename__ = "a"
            id = Column(Integer, primary_key=True)
            # deferred column; the event handlers in this class read it
            unloaded = deferred(Column(String(50)))
            # collection of B, loaded with lazy="joined"
            bs = relationship("B", lazy="joined")

        class B(cls.DeclarativeBasic):
            __tablename__ = "b"
            id = Column(Integer, primary_key=True)
            a_id = Column(ForeignKey("a.id"))
654
655    @classmethod
656    def insert_data(cls, connection):
657        A, B = cls.classes("A", "B")
658        s = Session(connection)
659        s.add(A(bs=[B(), B(), B()]))
660        s.commit()
661
662    def _combinations(fn):
663        return testing.combinations(
664            (lambda A: A, "load", lambda instance, context: instance.unloaded),
665            (
666                lambda A: A,
667                "refresh",
668                lambda instance, context, attrs: instance.unloaded,
669            ),
670            (
671                lambda session: session,
672                "loaded_as_persistent",
673                lambda session, instance: instance.unloaded
674                if instance.__class__.__name__ == "A"
675                else None,
676            ),
677            argnames="target, event_name, fn",
678        )(fn)
679
680    def teardown(self):
681        A = self.classes.A
682        A._sa_class_manager.dispatch._clear()
683
684    @_combinations
685    def test_warning(self, target, event_name, fn):
686        A = self.classes.A
687        s = Session()
688        target = testing.util.resolve_lambda(target, A=A, session=s)
689        event.listen(target, event_name, fn)
690
691        with expect_warnings(
692            r"Loading context for \<A at .*\> has changed within a "
693            r"load/refresh handler, suggesting a row refresh operation "
694            r"took place. "
695            r"If this event handler is expected to be emitting row refresh "
696            r"operations within an existing load or refresh operation, set "
697            r"restore_load_context=True when establishing the listener to "
698            r"ensure the context remains unchanged when the event handler "
699            r"completes."
700        ):
701            a1 = s.query(A).all()[0]
702            if event_name == "refresh":
703                s.refresh(a1)
704        # joined eager load didn't continue
705        eq_(len(a1.bs), 1)
706
707    @_combinations
708    def test_flag_resolves_existing(self, target, event_name, fn):
709        A = self.classes.A
710        s = Session()
711        target = testing.util.resolve_lambda(target, A=A, session=s)
712
713        a1 = s.query(A).all()[0]
714
715        s.expire(a1)
716        event.listen(target, event_name, fn, restore_load_context=True)
717        s.query(A).all()
718
719    @testing.combinations(
720        ("load", lambda instance, context: instance.unloaded),
721        (
722            "refresh",
723            lambda instance, context, attrs: instance.unloaded,
724        ),
725    )
726    def test_flag_resolves_existing_for_subclass(self, event_name, fn):
727        Base = declarative_base()
728
729        event.listen(
730            Base, event_name, fn, propagate=True, restore_load_context=True
731        )
732
733        class A(Base):
734            __tablename__ = "a"
735            id = Column(Integer, primary_key=True)
736            unloaded = deferred(Column(String(50)))
737
738        s = Session(testing.db)
739
740        a1 = s.query(A).all()[0]
741        if event_name == "refresh":
742            s.refresh(a1)
743        s.close()
744
745    @_combinations
746    def test_flag_resolves(self, target, event_name, fn):
747        A = self.classes.A
748        s = Session()
749        target = testing.util.resolve_lambda(target, A=A, session=s)
750        event.listen(target, event_name, fn, restore_load_context=True)
751
752        a1 = s.query(A).all()[0]
753        if event_name == "refresh":
754            s.refresh(a1)
755        # joined eager load continued
756        eq_(len(a1.bs), 3)
757
758
class DeclarativeEventListenTest(
    _RemoveListeners, fixtures.DeclarativeMappedTest
):
    """Event listening against a declarative base class."""

    run_setup_classes = "each"
    run_deletes = None

    def test_inheritance_propagate_after_config(self):
        """A propagating listener reaches classes declared before and after.

        Covers [ticket:2949]: ``A`` and ``B`` exist when the listener is
        attached to the declarative base, ``C`` is declared afterwards.
        """

        class A(self.DeclarativeBasic):
            __tablename__ = "a"
            id = Column(Integer, primary_key=True)

        class B(A):
            pass

        load_hook = Mock()
        event.listen(
            self.DeclarativeBasic, "load", load_hook, propagate=True
        )

        class C(B):
            pass

        a_obj, b_obj, c_obj = A(), B(), C()

        # dispatch "load" by hand, most derived class first
        for obj, tag in ((c_obj, "c"), (b_obj, "b"), (a_obj, "a")):
            manager = type(obj).__mapper__.class_manager
            manager.dispatch.load(obj._sa_instance_state, tag)

        eq_(
            load_hook.mock_calls,
            [call(c_obj, "c"), call(b_obj, "b"), call(a_obj, "a")],
        )
791
792
793class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
794
    """test event listeners against unmapped classes.
796
797    This incurs special logic.  Note if we ever do the "remove" case,
798    it has to get all of these, too.
799
800    """
801
802    run_inserts = None
803
804    def test_deferred_map_event(self):
805        """
806        1. mapper event listen on class
807        2. map class
808        3. event fire should receive event
809
810        """
811        users, User = (self.tables.users, self.classes.User)
812
813        canary = []
814
815        def evt(x, y, z):
816            canary.append(x)
817
818        event.listen(User, "before_insert", evt, raw=True)
819
820        m = mapper(User, users)
821        m.dispatch.before_insert(5, 6, 7)
822        eq_(canary, [5])
823
824    def test_deferred_map_event_subclass_propagate(self):
825        """
826        1. mapper event listen on class, w propagate
827        2. map only subclass of class
828        3. event fire should receive event
829
830        """
831        users, User = (self.tables.users, self.classes.User)
832
833        class SubUser(User):
834            pass
835
836        class SubSubUser(SubUser):
837            pass
838
839        canary = Mock()
840
841        def evt(x, y, z):
842            canary.append(x)
843
844        event.listen(User, "before_insert", canary, propagate=True, raw=True)
845
846        m = mapper(SubUser, users)
847        m.dispatch.before_insert(5, 6, 7)
848        eq_(canary.mock_calls, [call(5, 6, 7)])
849
850        m2 = mapper(SubSubUser, users)
851
852        m2.dispatch.before_insert(8, 9, 10)
853        eq_(canary.mock_calls, [call(5, 6, 7), call(8, 9, 10)])
854
855    def test_deferred_map_event_subclass_no_propagate(self):
856        """
857        1. mapper event listen on class, w/o propagate
858        2. map only subclass of class
859        3. event fire should not receive event
860
861        """
862        users, User = (self.tables.users, self.classes.User)
863
864        class SubUser(User):
865            pass
866
867        canary = []
868
869        def evt(x, y, z):
870            canary.append(x)
871
872        event.listen(User, "before_insert", evt, propagate=False)
873
874        m = mapper(SubUser, users)
875        m.dispatch.before_insert(5, 6, 7)
876        eq_(canary, [])
877
878    def test_deferred_map_event_subclass_post_mapping_propagate(self):
879        """
880        1. map only subclass of class
881        2. mapper event listen on class, w propagate
882        3. event fire should receive event
883
884        """
885        users, User = (self.tables.users, self.classes.User)
886
887        class SubUser(User):
888            pass
889
890        m = mapper(SubUser, users)
891
892        canary = []
893
894        def evt(x, y, z):
895            canary.append(x)
896
897        event.listen(User, "before_insert", evt, propagate=True, raw=True)
898
899        m.dispatch.before_insert(5, 6, 7)
900        eq_(canary, [5])
901
902    def test_deferred_map_event_subclass_post_mapping_propagate_two(self):
903        """
904        1. map only subclass of class
905        2. mapper event listen on class, w propagate
906        3. event fire should receive event
907
908        """
909        users, User = (self.tables.users, self.classes.User)
910
911        class SubUser(User):
912            pass
913
914        class SubSubUser(SubUser):
915            pass
916
917        m = mapper(SubUser, users)
918
919        canary = Mock()
920        event.listen(User, "before_insert", canary, propagate=True, raw=True)
921
922        m2 = mapper(SubSubUser, users)
923
924        m.dispatch.before_insert(5, 6, 7)
925        eq_(canary.mock_calls, [call(5, 6, 7)])
926
927        m2.dispatch.before_insert(8, 9, 10)
928        eq_(canary.mock_calls, [call(5, 6, 7), call(8, 9, 10)])
929
930    def test_deferred_instance_event_subclass_post_mapping_propagate(self):
931        """
932        1. map only subclass of class
933        2. instance event listen on class, w propagate
934        3. event fire should receive event
935
936        """
937        users, User = (self.tables.users, self.classes.User)
938
939        class SubUser(User):
940            pass
941
942        m = mapper(SubUser, users)
943
944        canary = []
945
946        def evt(x):
947            canary.append(x)
948
949        event.listen(User, "load", evt, propagate=True, raw=True)
950
951        m.class_manager.dispatch.load(5)
952        eq_(canary, [5])
953
954    def test_deferred_instance_event_plain(self):
955        """
956        1. instance event listen on class, w/o propagate
957        2. map class
958        3. event fire should receive event
959
960        """
961        users, User = (self.tables.users, self.classes.User)
962
963        canary = []
964
965        def evt(x):
966            canary.append(x)
967
968        event.listen(User, "load", evt, raw=True)
969
970        m = mapper(User, users)
971        m.class_manager.dispatch.load(5)
972        eq_(canary, [5])
973
974    def test_deferred_instance_event_subclass_propagate_subclass_only(self):
975        """
976        1. instance event listen on class, w propagate
977        2. map two subclasses of class
978        3. event fire on each class should receive one and only one event
979
980        """
981        users, User = (self.tables.users, self.classes.User)
982
983        class SubUser(User):
984            pass
985
986        class SubUser2(User):
987            pass
988
989        canary = []
990
991        def evt(x):
992            canary.append(x)
993
994        event.listen(User, "load", evt, propagate=True, raw=True)
995
996        m = mapper(SubUser, users)
997        m2 = mapper(SubUser2, users)
998
999        m.class_manager.dispatch.load(5)
1000        eq_(canary, [5])
1001
1002        m2.class_manager.dispatch.load(5)
1003        eq_(canary, [5, 5])
1004
1005    def test_deferred_instance_event_subclass_propagate_baseclass(self):
1006        """
1007        1. instance event listen on class, w propagate
1008        2. map one subclass of class, map base class, leave 2nd subclass
1009           unmapped
1010        3. event fire on sub should receive one and only one event
1011        4. event fire on base should receive one and only one event
1012        5. map 2nd subclass
1013        6. event fire on 2nd subclass should receive one and only one event
1014        """
1015        users, User = (self.tables.users, self.classes.User)
1016
1017        class SubUser(User):
1018            pass
1019
1020        class SubUser2(User):
1021            pass
1022
1023        canary = Mock()
1024        event.listen(User, "load", canary, propagate=True, raw=False)
1025
1026        # reversing these fixes....
1027        m = mapper(SubUser, users)
1028        m2 = mapper(User, users)
1029
1030        instance = Mock()
1031        m.class_manager.dispatch.load(instance)
1032
1033        eq_(canary.mock_calls, [call(instance.obj())])
1034
1035        m2.class_manager.dispatch.load(instance)
1036        eq_(canary.mock_calls, [call(instance.obj()), call(instance.obj())])
1037
1038        m3 = mapper(SubUser2, users)
1039        m3.class_manager.dispatch.load(instance)
1040        eq_(
1041            canary.mock_calls,
1042            [call(instance.obj()), call(instance.obj()), call(instance.obj())],
1043        )
1044
1045    def test_deferred_instance_event_subclass_no_propagate(self):
1046        """
1047        1. instance event listen on class, w/o propagate
1048        2. map subclass
1049        3. event fire on subclass should not receive event
1050        """
1051        users, User = (self.tables.users, self.classes.User)
1052
1053        class SubUser(User):
1054            pass
1055
1056        canary = []
1057
1058        def evt(x):
1059            canary.append(x)
1060
1061        event.listen(User, "load", evt, propagate=False)
1062
1063        m = mapper(SubUser, users)
1064        m.class_manager.dispatch.load(5)
1065        eq_(canary, [])
1066
1067    def test_deferred_instrument_event(self):
1068        User = self.classes.User
1069
1070        canary = []
1071
1072        def evt(x):
1073            canary.append(x)
1074
1075        event.listen(User, "attribute_instrument", evt)
1076
1077        instrumentation._instrumentation_factory.dispatch.attribute_instrument(
1078            User
1079        )
1080        eq_(canary, [User])
1081
1082    def test_isolation_instrument_event(self):
1083        User = self.classes.User
1084
1085        class Bar(object):
1086            pass
1087
1088        canary = []
1089
1090        def evt(x):
1091            canary.append(x)
1092
1093        event.listen(Bar, "attribute_instrument", evt)
1094
1095        instrumentation._instrumentation_factory.dispatch.attribute_instrument(
1096            User
1097        )
1098        eq_(canary, [])
1099
    @testing.requires.predictable_gc
    def test_instrument_event_auto_remove(self):
        """An instrumentation listener goes away with its target class.

        A listener is attached for a throwaway class; after the local
        name for that class is deleted and a collection is forced, the
        factory-level "attribute_instrument" dispatch must be empty again.
        """

        class Bar(object):
            pass

        dispatch = instrumentation._instrumentation_factory.dispatch
        # precondition: nothing is listening for attribute_instrument yet
        assert not dispatch.attribute_instrument

        event.listen(Bar, "attribute_instrument", lambda: None)

        # the listener shows up on the factory-level dispatcher
        eq_(len(dispatch.attribute_instrument), 1)

        # remove this test's local name for the class, then collect
        del Bar
        gc_collect()

        assert not dispatch.attribute_instrument
1116
1117    def test_deferred_instrument_event_subclass_propagate(self):
1118        User = self.classes.User
1119
1120        class SubUser(User):
1121            pass
1122
1123        canary = []
1124
1125        def evt(x):
1126            canary.append(x)
1127
1128        event.listen(User, "attribute_instrument", evt, propagate=True)
1129
1130        instrumentation._instrumentation_factory.dispatch.attribute_instrument(
1131            SubUser
1132        )
1133        eq_(canary, [SubUser])
1134
1135    def test_deferred_instrument_event_subclass_no_propagate(self):
1136        users, User = (self.tables.users, self.classes.User)
1137
1138        class SubUser(User):
1139            pass
1140
1141        canary = []
1142
1143        def evt(x):
1144            canary.append(x)
1145
1146        event.listen(User, "attribute_instrument", evt, propagate=False)
1147
1148        mapper(SubUser, users)
1149        instrumentation._instrumentation_factory.dispatch.attribute_instrument(
1150            5
1151        )
1152        eq_(canary, [])
1153
1154
class LoadTest(_fixtures.FixtureTest):
    """Checks on the instance "load" event for rows fetched by a query."""

    run_inserts = None

    @classmethod
    def setup_mappers(cls):
        """Map the User class to the users table."""
        users = cls.tables.users
        User = cls.classes.User

        mapper(User, users)

    def _fixture(self):
        """Attach "load" and "refresh" listeners to User.

        Returns the list the listeners write to: the string ``"load"``
        per load event and a ``("refresh", attrs)`` tuple per refresh
        event.
        """
        User = self.classes.User
        recorded = []

        @event.listens_for(User, "load")
        def on_load(target, ctx):
            recorded.append("load")

        @event.listens_for(User, "refresh")
        def on_refresh(target, ctx, attrs):
            recorded.append(("refresh", attrs))

        return recorded

    def test_just_loaded(self):
        """Querying a row after the session was closed emits one "load"."""
        User = self.classes.User
        recorded = self._fixture()

        sess = Session()
        user = User(name="u1")
        sess.add(user)
        sess.commit()
        sess.close()

        sess.query(User).first()
        eq_(recorded, ["load"])

    def test_repeated_rows(self):
        """A UNION ALL returning the same row twice emits one "load"."""
        User = self.classes.User
        recorded = self._fixture()

        sess = Session()
        user = User(name="u1")
        sess.add(user)
        sess.commit()
        sess.close()

        doubled = sess.query(User).union_all(sess.query(User))
        doubled.all()
        eq_(recorded, ["load"])
1208
1209
class RemovalTest(_fixtures.FixtureTest):
    """Checks that event.remove() stops listeners from being invoked."""

    run_inserts = None

    def test_attr_propagated(self):
        """Removing a propagated attribute listener silences subclasses."""
        User = self.classes.User
        users = self.tables.users
        addresses = self.tables.addresses

        class AdminUser(User):
            pass

        mapper(User, users)
        mapper(
            AdminUser,
            addresses,
            inherits=User,
            properties={"address_id": addresses.c.id},
        )

        on_set = Mock()
        event.listen(User.name, "set", on_set, propagate=True)

        admin = AdminUser()

        # while listening, a set on the subclass instance is seen once
        admin.name = "ed"
        eq_(on_set.call_count, 1)

        event.remove(User.name, "set", on_set)

        # after removal, the call count does not move
        admin.name = "jack"
        eq_(on_set.call_count, 1)

    def test_unmapped_listen(self):
        """Removing a listener placed on an unmapped base class.

        The listener is registered with propagate=True on ``Foo``, which
        is never mapped itself; its subclasses are mapped before and
        after the removal.
        """
        users = self.tables.users

        class Foo(object):
            pass

        on_insert = Mock()
        event.listen(Foo, "before_insert", on_insert, propagate=True)

        class User(Foo):
            pass

        user_mapper = mapper(User, users)
        user_state = attributes.instance_state(User())

        user_mapper.dispatch.before_insert(user_mapper, None, user_state)
        eq_(on_insert.call_count, 1)

        event.remove(Foo, "before_insert", on_insert)

        # existing event is removed
        user_mapper.dispatch.before_insert(user_mapper, None, user_state)
        eq_(on_insert.call_count, 1)

        # the _HoldEvents is also cleaned out
        class Bar(Foo):
            pass

        bar_mapper = mapper(Bar, users)
        bar_state = attributes.instance_state(Bar())
        bar_mapper.dispatch.before_insert(bar_mapper, None, bar_state)
        eq_(on_insert.call_count, 1)

    def test_instance_event_listen_on_cls_before_map(self):
        """Removing a "load" listener that was added before mapping."""
        users = self.tables.users
        on_load = Mock()

        class User(object):
            pass

        event.listen(User, "load", on_load)
        user_mapper = mapper(User, users)
        load = user_mapper.class_manager.dispatch.load

        u1 = User()
        state = attributes.instance_state(u1)

        load(state, "u1")
        event.remove(User, "load", on_load)
        load(state, "u2")

        # only the dispatch made before the removal was delivered
        eq_(on_load.mock_calls, [call(u1, "u1")])
1300
1301
class RefreshTest(_fixtures.FixtureTest):
    """Checks on the instance "refresh" event and its ``attrs`` argument."""

    run_inserts = None

    @classmethod
    def setup_mappers(cls):
        """Map the User class to the users table."""
        users = cls.tables.users
        User = cls.classes.User

        mapper(User, users)

    def _fixture(self):
        """Attach "load" and "refresh" listeners to User.

        Returns the list the listeners write to: the string ``"load"``
        per load event and a ``("refresh", attrs)`` tuple per refresh
        event.
        """
        User = self.classes.User
        recorded = []

        @event.listens_for(User, "load")
        def on_load(target, ctx):
            recorded.append("load")

        @event.listens_for(User, "refresh")
        def on_refresh(target, ctx, attrs):
            recorded.append(("refresh", attrs))

        return recorded

    def test_already_present(self):
        """Querying a flushed, unexpired object emits no events."""
        User = self.classes.User
        recorded = self._fixture()

        sess = Session()
        u1 = User(name="u1")
        sess.add(u1)
        sess.flush()

        sess.query(User).first()
        eq_(recorded, [])

    def test_changes_reset(self):
        """test the contract of load/refresh such that history is reset.

        This has never been an official contract but we are testing it
        here to ensure it is maintained given the loading performance
        enhancements.

        """
        User = self.classes.User

        def rename_on_load(obj, context):
            obj.name = "new name!"

        def rename_on_refresh(obj, context, props):
            obj.name = "refreshed name!"

        event.listen(User, "load", rename_on_load)
        event.listen(User, "refresh", rename_on_refresh)

        sess = Session()
        u1 = User(name="u1")
        sess.add(u1)
        sess.commit()
        sess.close()

        # value assigned inside "load" shows up as unchanged history
        u1 = sess.query(User).first()
        history = attributes.get_history(u1, "name")
        eq_(history, ((), ["new name!"], ()))
        assert "name" not in attributes.instance_state(u1).committed_state
        assert u1 not in sess.dirty

        # attribute access on the expired object triggers the refresh
        sess.expire(u1)
        u1.id
        history = attributes.get_history(u1, "name")
        eq_(history, ((), ["refreshed name!"], ()))
        assert "name" not in attributes.instance_state(u1).committed_state
        assert u1 in sess.dirty

    def test_repeated_rows(self):
        """A UNION ALL returning the row twice emits a single refresh."""
        User = self.classes.User
        recorded = self._fixture()

        sess = Session()
        u1 = User(name="u1")
        sess.add(u1)
        sess.commit()

        doubled = sess.query(User).union_all(sess.query(User))
        doubled.all()
        eq_(recorded, [("refresh", {"id", "name"})])

    def test_via_refresh_state(self):
        """Attribute access after commit emits a refresh for id and name."""
        User = self.classes.User
        recorded = self._fixture()

        sess = Session()
        u1 = User(name="u1")
        sess.add(u1)
        sess.commit()

        u1.name
        eq_(recorded, [("refresh", {"id", "name"})])

    def test_was_expired(self):
        """Querying an explicitly expired object emits a refresh."""
        User = self.classes.User
        recorded = self._fixture()

        sess = Session()
        u1 = User(name="u1")
        sess.add(u1)
        sess.flush()
        sess.expire(u1)

        sess.query(User).first()
        eq_(recorded, [("refresh", {"id", "name"})])

    def test_was_expired_via_commit(self):
        """Querying an object after commit emits a refresh."""
        User = self.classes.User
        recorded = self._fixture()

        sess = Session()
        u1 = User(name="u1")
        sess.add(u1)
        sess.commit()

        sess.query(User).first()
        eq_(recorded, [("refresh", {"id", "name"})])

    def test_was_expired_attrs(self):
        """Expiring only "name" yields a refresh naming only "name"."""
        User = self.classes.User
        recorded = self._fixture()

        sess = Session()
        u1 = User(name="u1")
        sess.add(u1)
        sess.flush()
        sess.expire(u1, ["name"])

        sess.query(User).first()
        eq_(recorded, [("refresh", {"name"})])

    def test_populate_existing(self):
        """populate_existing() emits a refresh whose attrs is None."""
        User = self.classes.User
        recorded = self._fixture()

        sess = Session()
        u1 = User(name="u1")
        sess.add(u1)
        sess.commit()

        sess.query(User).populate_existing().first()
        eq_(recorded, [("refresh", None)])
1460
1461
class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
    """Checks on Session-level event registration and firing order."""

    run_inserts = None

    def test_class_listen(self):
        """A listener on the Session class appears on new instances."""

        @event.listens_for(Session, "before_flush")
        def my_listener(*arg, **kw):
            pass

        sess = Session()
        assert my_listener in sess.dispatch.before_flush

    def test_sessionmaker_listen(self):
        """test that listen can be applied to individual
        sessionmaker() objects."""

        def my_listener_one(*arg, **kw):
            pass

        def my_listener_two(*arg, **kw):
            pass

        S1 = sessionmaker()
        S2 = sessionmaker()

        event.listen(Session, "before_flush", my_listener_one)
        event.listen(S1, "before_flush", my_listener_two)

        # sessions from S1 carry both listeners
        from_s1 = S1().dispatch.before_flush
        assert my_listener_one in from_s1
        assert my_listener_two in from_s1

        # sessions from S2 carry only the Session-wide listener
        from_s2 = S2().dispatch.before_flush
        assert my_listener_one in from_s2
        assert my_listener_two not in from_s2

    def test_scoped_session_invalid_callable(self):
        """Listening on a scoped_session built from a plain callable."""
        from sqlalchemy.orm import scoped_session

        def my_listener_one(*arg, **kw):
            pass

        def plain_factory():
            return Session()

        scope = scoped_session(plain_factory)
        expected_message = (
            "Session event listen on a scoped_session requires that its "
            "creation callable is associated with the Session class."
        )

        assert_raises_message(
            sa.exc.ArgumentError,
            expected_message,
            event.listen,
            scope,
            "before_flush",
            my_listener_one,
        )

    def test_scoped_session_invalid_class(self):
        """Listening on a scoped_session built from a non-Session class."""
        from sqlalchemy.orm import scoped_session

        def my_listener_one(*arg, **kw):
            pass

        class NotASession(object):
            def __call__(self):
                return Session()

        scope = scoped_session(NotASession)
        expected_message = (
            "Session event listen on a scoped_session requires that its "
            "creation callable is associated with the Session class."
        )

        assert_raises_message(
            sa.exc.ArgumentError,
            expected_message,
            event.listen,
            scope,
            "before_flush",
            my_listener_one,
        )

    def test_scoped_session_listen(self):
        """Listening on a scoped_session built from a sessionmaker."""
        from sqlalchemy.orm import scoped_session

        def my_listener_one(*arg, **kw):
            pass

        scope = scoped_session(sessionmaker())
        event.listen(scope, "before_flush", my_listener_one)

        current = scope()
        assert my_listener_one in current.dispatch.before_flush

    def _listener_fixture(self, **kw):
        """Build a Session whose events are recorded by name.

        Keyword arguments are passed to ``Session()``.  Returns the
        session and the list that receives an event's name every time
        one of the listened-for events fires.
        """
        fired = []

        def recorder_for(event_name):
            def record(*args, **kwargs):
                fired.append(event_name)

            return record

        sess = Session(**kw)

        event_names = (
            "after_transaction_create",
            "after_transaction_end",
            "before_commit",
            "after_commit",
            "after_rollback",
            "after_soft_rollback",
            "before_flush",
            "after_flush",
            "after_flush_postexec",
            "after_begin",
            "before_attach",
            "after_attach",
            "after_bulk_update",
            "after_bulk_delete",
        )
        for event_name in event_names:
            event.listen(sess, event_name, recorder_for(event_name))

        return sess, fired

    def test_flush_autocommit_hook(self):
        """Event order for add + flush on an autocommit session."""
        User, users = self.classes.User, self.tables.users

        mapper(User, users)

        sess, fired = self._listener_fixture(
            autoflush=False, autocommit=True, expire_on_commit=False
        )

        user = User(name="u1")
        sess.add(user)
        sess.flush()

        expected = [
            "before_attach",
            "after_attach",
            "before_flush",
            "after_transaction_create",
            "after_begin",
            "after_flush",
            "after_flush_postexec",
            "before_commit",
            "after_commit",
            "after_transaction_end",
        ]
        eq_(fired, expected)

    def test_rollback_hook(self):
        """Event order for a commit, a failed commit, then a rollback."""
        User, users = self.classes.User, self.tables.users
        sess, fired = self._listener_fixture()
        mapper(User, users)

        first = User(name="u1", id=1)
        sess.add(first)
        sess.commit()

        # a second object with the same primary key makes the flush fail
        second = User(name="u1", id=1)
        sess.add(second)
        assert_raises(sa.orm.exc.FlushError, sess.commit)
        sess.rollback()

        expected = [
            "before_attach",
            "after_attach",
            "before_commit",
            "before_flush",
            "after_transaction_create",
            "after_begin",
            "after_flush",
            "after_flush_postexec",
            "after_transaction_end",
            "after_commit",
            "after_transaction_end",
            "after_transaction_create",
            "before_attach",
            "after_attach",
            "before_commit",
            "before_flush",
            "after_transaction_create",
            "after_begin",
            "after_rollback",
            "after_transaction_end",
            "after_soft_rollback",
            "after_transaction_end",
            "after_transaction_create",
            "after_soft_rollback",
        ]
        eq_(fired, expected)

    def test_can_use_session_in_outer_rollback_hook(self):
        """The session is usable in after_soft_rollback once active."""
        User, users = self.classes.User, self.tables.users
        mapper(User, users)

        sess = Session()
        observed = []

        def do_something(session, previous_transaction):
            if not session.is_active:
                return
            # "name" is not in __dict__, yet reading it gives the value
            observed.append("name" not in first.__dict__)
            observed.append(first.name == "u1")

        event.listen(sess, "after_soft_rollback", do_something)

        first = User(name="u1", id=1)
        sess.add(first)
        sess.commit()

        second = User(name="u1", id=1)
        sess.add(second)
        assert_raises(sa.orm.exc.FlushError, sess.commit)
        sess.rollback()
        eq_(observed, [True, True])

    def test_flush_noautocommit_hook(self):
        """Event order for add + flush on a default session."""
        User, users = self.classes.User, self.tables.users

        sess, fired = self._listener_fixture()

        mapper(User, users)

        user = User(name="u1")
        sess.add(user)
        sess.flush()

        expected = [
            "before_attach",
            "after_attach",
            "before_flush",
            "after_transaction_create",
            "after_begin",
            "after_flush",
            "after_flush_postexec",
            "after_transaction_end",
        ]
        eq_(fired, expected)

    def test_flush_in_commit_hook(self):
        """Event order for a commit that has to flush a pending change."""
        User, users = self.classes.User, self.tables.users

        sess, fired = self._listener_fixture()

        mapper(User, users)
        user = User(name="u1")
        sess.add(user)
        sess.flush()

        # forget everything recorded so far; only the commit is of interest
        del fired[:]

        user.name = "ed"
        sess.commit()

        expected = [
            "before_commit",
            "before_flush",
            "after_transaction_create",
            "after_flush",
            "after_flush_postexec",
            "after_transaction_end",
            "after_commit",
            "after_transaction_end",
            "after_transaction_create",
        ]
        eq_(fired, expected)

    def test_state_before_attach(self):
        """In before_attach the instance is not yet part of the session."""
        User, users = self.classes.User, self.tables.users
        sess = Session()

        def listener(session, inst):
            state = attributes.instance_state(inst)
            if not state.key:
                assert inst not in session.new
            else:
                assert state.key not in session.identity_map

        event.listen(sess, "before_attach", listener)

        mapper(User, users)
        user = User(name="u1")
        sess.add(user)
        sess.flush()
        sess.expunge(user)
        sess.add(user)

    def test_state_after_attach(self):
        """In after_attach the instance is already part of the session."""
        User, users = self.classes.User, self.tables.users
        sess = Session()

        def listener(session, inst):
            state = attributes.instance_state(inst)
            if not state.key:
                assert inst in session.new
            else:
                assert session.identity_map[state.key] is inst

        event.listen(sess, "after_attach", listener)

        mapper(User, users)
        user = User(name="u1")
        sess.add(user)
        sess.flush()
        sess.expunge(user)
        sess.add(user)

    def test_standalone_on_commit_hook(self):
        """Event order for commit() on a session with nothing in it."""
        sess, fired = self._listener_fixture()
        sess.commit()

        expected = [
            "before_commit",
            "after_commit",
            "after_transaction_end",
            "after_transaction_create",
        ]
        eq_(fired, expected)

    def test_on_bulk_update_hook(self):
        """after_bulk_update with one-argument and four-argument listeners.

        The four-argument listener must receive the session, query,
        context and result found on the object handed to the
        one-argument listener.
        """
        User, users = self.classes.User, self.tables.users

        sess = Session()
        canary = Mock()

        def legacy(ses, qry, ctx, res):
            canary.after_bulk_update_legacy(ses, qry, ctx, res)

        event.listen(sess, "after_begin", canary.after_begin)
        event.listen(sess, "after_bulk_update", canary.after_bulk_update)
        event.listen(sess, "after_bulk_update", legacy)

        mapper(User, users)

        sess.query(User).update({"name": "foo"})

        eq_(canary.after_begin.call_count, 1)
        eq_(canary.after_bulk_update.call_count, 1)

        # first positional argument of the first recorded call
        upd = canary.after_bulk_update.mock_calls[0][1][0]
        eq_(upd.session, sess)

        legacy_calls = canary.after_bulk_update_legacy.mock_calls
        eq_(legacy_calls, [call(sess, upd.query, upd.context, upd.result)])

    def test_on_bulk_delete_hook(self):
        """after_bulk_delete with one-argument and four-argument listeners.

        The four-argument listener must receive the session, query,
        context and result found on the object handed to the
        one-argument listener.
        """
        User, users = self.classes.User, self.tables.users

        sess = Session()
        canary = Mock()

        def legacy(ses, qry, ctx, res):
            canary.after_bulk_delete_legacy(ses, qry, ctx, res)

        event.listen(sess, "after_begin", canary.after_begin)
        event.listen(sess, "after_bulk_delete", canary.after_bulk_delete)
        event.listen(sess, "after_bulk_delete", legacy)

        mapper(User, users)

        sess.query(User).delete()

        eq_(canary.after_begin.call_count, 1)
        eq_(canary.after_bulk_delete.call_count, 1)

        # first positional argument of the first recorded call
        upd = canary.after_bulk_delete.mock_calls[0][1][0]
        eq_(upd.session, sess)

        legacy_calls = canary.after_bulk_delete_legacy.mock_calls
        eq_(legacy_calls, [call(sess, upd.query, upd.context, upd.result)])

    def test_connection_emits_after_begin(self):
        """Asking a bound session for its connection emits after_begin."""
        sess, fired = self._listener_fixture(bind=testing.db)
        sess.connection()
        eq_(fired, ["after_begin"])
        sess.close()

    def test_reentrant_flush(self):
        """Calling flush() from inside before_flush raises."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        sess = Session()

        @event.listens_for(sess, "before_flush")
        def before_flush(session, flush_context, objects):
            session.flush()

        sess.add(User(name="foo"))
        assert_raises_message(
            sa.exc.InvalidRequestError, "already flushing", sess.flush
        )

    def test_before_flush_affects_flush_plan(self):
        """Objects added or deleted inside before_flush join the flush."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        def before_flush(session, flush_context, objects):
            # every new or dirty User gains an "another <name>" companion
            new_and_dirty = list(session.new) + list(session.dirty)
            for obj in new_and_dirty:
                if not isinstance(obj, User):
                    continue
                session.add(User(name="another %s" % obj.name))

            # every deleted User takes its companion along
            for obj in list(session.deleted):
                if not isinstance(obj, User):
                    continue
                companion = (
                    session.query(User)
                    .filter(User.name == "another %s" % obj.name)
                    .one()
                )
                session.delete(companion)

        sess = Session()
        event.listen(sess, "before_flush", before_flush)

        def users_by_name():
            return sess.query(User).order_by(User.name).all()

        u = User(name="u1")
        sess.add(u)
        sess.flush()
        eq_(users_by_name(), [User(name="another u1"), User(name="u1")])

        # a flush with nothing to do adds nothing
        sess.flush()
        eq_(users_by_name(), [User(name="another u1"), User(name="u1")])

        u.name = "u2"
        sess.flush()
        eq_(
            users_by_name(),
            [
                User(name="another u1"),
                User(name="another u2"),
                User(name="u2"),
            ],
        )

        sess.delete(u)
        sess.flush()
        eq_(users_by_name(), [User(name="another u1")])

    def test_before_flush_affects_dirty(self):
        """Attribute changes made inside before_flush get flushed."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        def before_flush(session, flush_context, objects):
            for obj in list(session.identity_map.values()):
                obj.name += " modified"

        sess = Session(autoflush=True)
        event.listen(sess, "before_flush", before_flush)

        def users_by_name():
            return sess.query(User).order_by(User.name).all()

        u = User(name="u1")
        sess.add(u)
        sess.flush()
        eq_(users_by_name(), [User(name="u1")])

        sess.add(User(name="u2"))
        sess.flush()
        sess.expunge_all()
        eq_(users_by_name(), [User(name="u1 modified"), User(name="u2")])

    def test_snapshot_still_present_after_commit(self):
        """Loaded state is still there while after_commit runs."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        sess = Session()

        u1 = User(name="u1")

        sess.add(u1)
        sess.commit()

        u1 = sess.query(User).first()

        def assert_state(session):
            assert "name" in u1.__dict__
            eq_(u1.name, "u1")

        event.listen(sess, "after_commit", assert_state)

        sess.commit()
        # once commit() has returned, the attribute is gone from __dict__
        assert "name" not in u1.__dict__

    def test_snapshot_still_present_after_rollback(self):
        """Loaded state is still there while after_rollback runs."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        sess = Session()

        u1 = User(name="u1")

        sess.add(u1)
        sess.commit()

        u1 = sess.query(User).first()

        def assert_state(session):
            assert "name" in u1.__dict__
            eq_(u1.name, "u1")

        event.listen(sess, "after_rollback", assert_state)

        sess.rollback()
        # once rollback() has returned, the attribute is gone from __dict__
        assert "name" not in u1.__dict__
1976
1977
1978class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest):
1979    run_inserts = None
1980
1981    def _fixture(self, include_address=False):
1982        users, User = self.tables.users, self.classes.User
1983
1984        if include_address:
1985            addresses, Address = self.tables.addresses, self.classes.Address
1986            mapper(
1987                User,
1988                users,
1989                properties={
1990                    "addresses": relationship(
1991                        Address, cascade="all, delete-orphan"
1992                    )
1993                },
1994            )
1995            mapper(Address, addresses)
1996        else:
1997            mapper(User, users)
1998
1999        listener = Mock()
2000
2001        sess = Session()
2002
2003        def start_events():
2004            event.listen(
2005                sess, "transient_to_pending", listener.transient_to_pending
2006            )
2007            event.listen(
2008                sess, "pending_to_transient", listener.pending_to_transient
2009            )
2010            event.listen(
2011                sess,
2012                "persistent_to_transient",
2013                listener.persistent_to_transient,
2014            )
2015            event.listen(
2016                sess, "pending_to_persistent", listener.pending_to_persistent
2017            )
2018            event.listen(
2019                sess, "detached_to_persistent", listener.detached_to_persistent
2020            )
2021            event.listen(
2022                sess, "loaded_as_persistent", listener.loaded_as_persistent
2023            )
2024
2025            event.listen(
2026                sess, "persistent_to_detached", listener.persistent_to_detached
2027            )
2028            event.listen(
2029                sess, "deleted_to_detached", listener.deleted_to_detached
2030            )
2031
2032            event.listen(
2033                sess, "persistent_to_deleted", listener.persistent_to_deleted
2034            )
2035            event.listen(
2036                sess, "deleted_to_persistent", listener.deleted_to_persistent
2037            )
2038            return listener
2039
2040        if include_address:
2041            return sess, User, Address, start_events
2042        else:
2043            return sess, User, start_events
2044
2045    def test_transient_to_pending(self):
2046        sess, User, start_events = self._fixture()
2047
2048        listener = start_events()
2049
2050        @event.listens_for(sess, "transient_to_pending")
2051        def trans_to_pending(session, instance):
2052            assert instance in session
2053            listener.flag_checked(instance)
2054
2055        u1 = User(name="u1")
2056        sess.add(u1)
2057
2058        eq_(
2059            listener.mock_calls,
2060            [call.transient_to_pending(sess, u1), call.flag_checked(u1)],
2061        )
2062
2063    def test_pending_to_transient_via_rollback(self):
2064        sess, User, start_events = self._fixture()
2065
2066        u1 = User(name="u1")
2067        sess.add(u1)
2068
2069        listener = start_events()
2070
2071        @event.listens_for(sess, "pending_to_transient")
2072        def test_deleted_flag(session, instance):
2073            assert instance not in session
2074            listener.flag_checked(instance)
2075
2076        sess.rollback()
2077        assert u1 not in sess
2078
2079        eq_(
2080            listener.mock_calls,
2081            [call.pending_to_transient(sess, u1), call.flag_checked(u1)],
2082        )
2083
2084    def test_pending_to_transient_via_expunge(self):
2085        sess, User, start_events = self._fixture()
2086
2087        u1 = User(name="u1")
2088        sess.add(u1)
2089
2090        listener = start_events()
2091
2092        @event.listens_for(sess, "pending_to_transient")
2093        def test_deleted_flag(session, instance):
2094            assert instance not in session
2095            listener.flag_checked(instance)
2096
2097        sess.expunge(u1)
2098        assert u1 not in sess
2099
2100        eq_(
2101            listener.mock_calls,
2102            [call.pending_to_transient(sess, u1), call.flag_checked(u1)],
2103        )
2104
2105    def test_pending_to_persistent(self):
2106        sess, User, start_events = self._fixture()
2107
2108        u1 = User(name="u1")
2109        sess.add(u1)
2110
2111        listener = start_events()
2112
2113        @event.listens_for(sess, "pending_to_persistent")
2114        def test_flag(session, instance):
2115            assert instance in session
2116            assert instance._sa_instance_state.persistent
2117            assert instance._sa_instance_state.key in session.identity_map
2118            listener.flag_checked(instance)
2119
2120        sess.flush()
2121
2122        eq_(
2123            listener.mock_calls,
2124            [call.pending_to_persistent(sess, u1), call.flag_checked(u1)],
2125        )
2126
2127        u1.name = "u2"
2128        sess.flush()
2129
2130        # event was not called again
2131        eq_(
2132            listener.mock_calls,
2133            [call.pending_to_persistent(sess, u1), call.flag_checked(u1)],
2134        )
2135
    def test_pending_to_persistent_del(self):
        """Flush an added object after the test dropped its own reference.

        The test deletes its only local name for the object and runs
        ``gc_collect()`` before flushing; the ``pending_to_persistent``
        listener must still be handed a real instance rather than ``None``.
        """
        sess, User, start_events = self._fixture()

        # registered before start_events(); ``listener`` is resolved lazily
        # from the enclosing scope when the event actually fires
        @event.listens_for(sess, "pending_to_persistent")
        def pending_to_persistent(session, instance):
            listener.flag_checked(instance)
            # this is actually u1, because
            # we have a strong ref internally
            is_not(None, instance)

        u1 = User(name="u1")
        sess.add(u1)

        # keep only the instance state, then drop the local reference to
        # the object itself
        u1_inst_state = u1._sa_instance_state
        del u1

        gc_collect()

        listener = start_events()

        sess.flush()

        # flag_checked is expected first: the decorated listener above was
        # registered before the mock's own listeners (presumably listeners
        # run in registration order -- confirm against the event API)
        eq_(
            listener.mock_calls,
            [
                call.flag_checked(u1_inst_state.obj()),
                call.pending_to_persistent(sess, u1_inst_state.obj()),
            ],
        )
2165
    def test_persistent_to_deleted_del(self):
        """Flush a delete after the test dropped its own reference.

        The object is marked for deletion, the test's local name for it is
        deleted and ``gc_collect()`` is run; the ``persistent_to_deleted``
        listener must still be handed a real instance rather than ``None``.
        """
        sess, User, start_events = self._fixture()

        u1 = User(name="u1")
        sess.add(u1)
        sess.flush()

        listener = start_events()

        @event.listens_for(sess, "persistent_to_deleted")
        def persistent_to_deleted(session, instance):
            is_not(None, instance)
            listener.flag_checked(instance)

        sess.delete(u1)
        # keep only the instance state so the object can still be looked up
        # for the final comparison
        u1_inst_state = u1._sa_instance_state

        del u1
        gc_collect()

        sess.flush()

        eq_(
            listener.mock_calls,
            [
                call.persistent_to_deleted(sess, u1_inst_state.obj()),
                call.flag_checked(u1_inst_state.obj()),
            ],
        )
2195
2196    def test_detached_to_persistent(self):
2197        sess, User, start_events = self._fixture()
2198
2199        u1 = User(name="u1")
2200        sess.add(u1)
2201        sess.flush()
2202
2203        sess.expunge(u1)
2204
2205        listener = start_events()
2206
2207        @event.listens_for(sess, "detached_to_persistent")
2208        def test_deleted_flag(session, instance):
2209            assert instance not in session.deleted
2210            assert instance in session
2211            listener.flag_checked()
2212
2213        sess.add(u1)
2214
2215        eq_(
2216            listener.mock_calls,
2217            [call.detached_to_persistent(sess, u1), call.flag_checked()],
2218        )
2219
2220    def test_loaded_as_persistent(self):
2221        sess, User, start_events = self._fixture()
2222
2223        u1 = User(name="u1")
2224        sess.add(u1)
2225        sess.commit()
2226        sess.close()
2227
2228        listener = start_events()
2229
2230        @event.listens_for(sess, "loaded_as_persistent")
2231        def test_identity_flag(session, instance):
2232            assert instance in session
2233            assert instance._sa_instance_state.persistent
2234            assert instance._sa_instance_state.key in session.identity_map
2235            assert not instance._sa_instance_state.deleted
2236            assert not instance._sa_instance_state.detached
2237            assert instance._sa_instance_state.persistent
2238            listener.flag_checked(instance)
2239
2240        u1 = sess.query(User).filter_by(name="u1").one()
2241
2242        eq_(
2243            listener.mock_calls,
2244            [call.loaded_as_persistent(sess, u1), call.flag_checked(u1)],
2245        )
2246
2247    def test_detached_to_persistent_via_deleted(self):
2248        sess, User, start_events = self._fixture()
2249
2250        u1 = User(name="u1")
2251        sess.add(u1)
2252        sess.commit()
2253        sess.close()
2254
2255        listener = start_events()
2256
2257        @event.listens_for(sess, "detached_to_persistent")
2258        def test_deleted_flag_persistent(session, instance):
2259            assert instance not in session.deleted
2260            assert instance in session
2261            assert not instance._sa_instance_state.deleted
2262            assert not instance._sa_instance_state.detached
2263            assert instance._sa_instance_state.persistent
2264            listener.dtp_flag_checked(instance)
2265
2266        @event.listens_for(sess, "persistent_to_deleted")
2267        def test_deleted_flag_detached(session, instance):
2268            assert instance not in session.deleted
2269            assert instance not in session
2270            assert not instance._sa_instance_state.persistent
2271            assert instance._sa_instance_state.deleted
2272            assert not instance._sa_instance_state.detached
2273            listener.ptd_flag_checked(instance)
2274
2275        sess.delete(u1)
2276        assert u1 in sess.deleted
2277
2278        eq_(
2279            listener.mock_calls,
2280            [call.detached_to_persistent(sess, u1), call.dtp_flag_checked(u1)],
2281        )
2282
2283        sess.flush()
2284
2285        eq_(
2286            listener.mock_calls,
2287            [
2288                call.detached_to_persistent(sess, u1),
2289                call.dtp_flag_checked(u1),
2290                call.persistent_to_deleted(sess, u1),
2291                call.ptd_flag_checked(u1),
2292            ],
2293        )
2294
2295    def test_detached_to_persistent_via_cascaded_delete(self):
2296        sess, User, Address, start_events = self._fixture(include_address=True)
2297
2298        u1 = User(name="u1")
2299        sess.add(u1)
2300        a1 = Address(email_address="e1")
2301        u1.addresses.append(a1)
2302        sess.commit()
2303        u1.addresses  # ensure u1.addresses refers to a1 before detachment
2304        sess.close()
2305
2306        listener = start_events()
2307
2308        @event.listens_for(sess, "detached_to_persistent")
2309        def test_deleted_flag(session, instance):
2310            assert instance not in session.deleted
2311            assert instance in session
2312            assert not instance._sa_instance_state.deleted
2313            assert not instance._sa_instance_state.detached
2314            assert instance._sa_instance_state.persistent
2315            listener.flag_checked(instance)
2316
2317        sess.delete(u1)
2318        assert u1 in sess.deleted
2319        assert a1 in sess.deleted
2320
2321        eq_(
2322            listener.mock_calls,
2323            [
2324                call.detached_to_persistent(sess, u1),
2325                call.flag_checked(u1),
2326                call.detached_to_persistent(sess, a1),
2327                call.flag_checked(a1),
2328            ],
2329        )
2330
2331        sess.flush()
2332
2333    def test_persistent_to_deleted(self):
2334        sess, User, start_events = self._fixture()
2335
2336        u1 = User(name="u1")
2337        sess.add(u1)
2338        sess.commit()
2339
2340        listener = start_events()
2341
2342        @event.listens_for(sess, "persistent_to_deleted")
2343        def test_deleted_flag(session, instance):
2344            assert instance not in session.deleted
2345            assert instance not in session
2346            assert instance._sa_instance_state.deleted
2347            assert not instance._sa_instance_state.detached
2348            assert not instance._sa_instance_state.persistent
2349            listener.flag_checked(instance)
2350
2351        sess.delete(u1)
2352        assert u1 in sess.deleted
2353
2354        eq_(listener.mock_calls, [])
2355
2356        sess.flush()
2357        assert u1 not in sess
2358
2359        eq_(
2360            listener.mock_calls,
2361            [call.persistent_to_deleted(sess, u1), call.flag_checked(u1)],
2362        )
2363
2364    def test_persistent_to_detached_via_expunge(self):
2365        sess, User, start_events = self._fixture()
2366
2367        u1 = User(name="u1")
2368        sess.add(u1)
2369        sess.flush()
2370
2371        listener = start_events()
2372
2373        @event.listens_for(sess, "persistent_to_detached")
2374        def test_deleted_flag(session, instance):
2375            assert instance not in session.deleted
2376            assert instance not in session
2377            assert not instance._sa_instance_state.deleted
2378            assert instance._sa_instance_state.detached
2379            assert not instance._sa_instance_state.persistent
2380            listener.flag_checked(instance)
2381
2382        assert u1 in sess
2383        sess.expunge(u1)
2384        assert u1 not in sess
2385
2386        eq_(
2387            listener.mock_calls,
2388            [call.persistent_to_detached(sess, u1), call.flag_checked(u1)],
2389        )
2390
2391    def test_persistent_to_detached_via_expunge_all(self):
2392        sess, User, start_events = self._fixture()
2393
2394        u1 = User(name="u1")
2395        sess.add(u1)
2396        sess.flush()
2397
2398        listener = start_events()
2399
2400        @event.listens_for(sess, "persistent_to_detached")
2401        def test_deleted_flag(session, instance):
2402            assert instance not in session.deleted
2403            assert instance not in session
2404            assert not instance._sa_instance_state.deleted
2405            assert instance._sa_instance_state.detached
2406            assert not instance._sa_instance_state.persistent
2407            listener.flag_checked(instance)
2408
2409        assert u1 in sess
2410        sess.expunge_all()
2411        assert u1 not in sess
2412
2413        eq_(
2414            listener.mock_calls,
2415            [call.persistent_to_detached(sess, u1), call.flag_checked(u1)],
2416        )
2417
2418    def test_persistent_to_transient_via_rollback(self):
2419        sess, User, start_events = self._fixture()
2420
2421        u1 = User(name="u1")
2422        sess.add(u1)
2423        sess.flush()
2424
2425        listener = start_events()
2426
2427        @event.listens_for(sess, "persistent_to_transient")
2428        def test_deleted_flag(session, instance):
2429            assert instance not in session.deleted
2430            assert instance not in session
2431            assert not instance._sa_instance_state.deleted
2432            assert not instance._sa_instance_state.detached
2433            assert not instance._sa_instance_state.persistent
2434            assert instance._sa_instance_state.transient
2435            listener.flag_checked(instance)
2436
2437        sess.rollback()
2438
2439        eq_(
2440            listener.mock_calls,
2441            [call.persistent_to_transient(sess, u1), call.flag_checked(u1)],
2442        )
2443
2444    def test_deleted_to_persistent_via_rollback(self):
2445        sess, User, start_events = self._fixture()
2446
2447        u1 = User(name="u1")
2448        sess.add(u1)
2449        sess.commit()
2450
2451        sess.delete(u1)
2452        sess.flush()
2453
2454        listener = start_events()
2455
2456        @event.listens_for(sess, "deleted_to_persistent")
2457        def test_deleted_flag(session, instance):
2458            assert instance not in session.deleted
2459            assert instance in session
2460            assert not instance._sa_instance_state.deleted
2461            assert not instance._sa_instance_state.detached
2462            assert instance._sa_instance_state.persistent
2463            listener.flag_checked(instance)
2464
2465        assert u1 not in sess
2466        assert u1._sa_instance_state.deleted
2467        assert not u1._sa_instance_state.persistent
2468        assert not u1._sa_instance_state.detached
2469
2470        sess.rollback()
2471
2472        assert u1 in sess
2473        assert u1._sa_instance_state.persistent
2474        assert not u1._sa_instance_state.deleted
2475        assert not u1._sa_instance_state.detached
2476
2477        eq_(
2478            listener.mock_calls,
2479            [call.deleted_to_persistent(sess, u1), call.flag_checked(u1)],
2480        )
2481
2482    def test_deleted_to_detached_via_commit(self):
2483        sess, User, start_events = self._fixture()
2484
2485        u1 = User(name="u1")
2486        sess.add(u1)
2487        sess.commit()
2488
2489        sess.delete(u1)
2490        sess.flush()
2491
2492        listener = start_events()
2493
2494        @event.listens_for(sess, "deleted_to_detached")
2495        def test_detached_flag(session, instance):
2496            assert instance not in session.deleted
2497            assert instance not in session
2498            assert not instance._sa_instance_state.deleted
2499            assert instance._sa_instance_state.detached
2500            listener.flag_checked(instance)
2501
2502        assert u1 not in sess
2503        assert u1._sa_instance_state.deleted
2504        assert not u1._sa_instance_state.persistent
2505        assert not u1._sa_instance_state.detached
2506
2507        sess.commit()
2508
2509        assert u1 not in sess
2510        assert not u1._sa_instance_state.deleted
2511        assert u1._sa_instance_state.detached
2512
2513        eq_(
2514            listener.mock_calls,
2515            [call.deleted_to_detached(sess, u1), call.flag_checked(u1)],
2516        )
2517
2518
class QueryEventsTest(
    _RemoveListeners,
    _fixtures.FixtureTest,
    AssertsCompiledSQL,
    testing.AssertsExecutionResults,
):
    """Tests for the ``before_compile*`` events of ``query.Query``."""

    __dialect__ = "default"

    @classmethod
    def setup_mappers(cls):
        """Map the fixture ``User`` class to the ``users`` table."""
        mapper(cls.classes.User, cls.tables.users)

    def test_before_compile(self):
        """A ``retval=True`` listener can add criteria to the query."""
        User = self.classes.User

        @event.listens_for(query.Query, "before_compile", retval=True)
        def exclude_id_10(query):
            for desc in query.column_descriptions:
                if desc["type"] is not User:
                    continue
                query = query.filter(desc["expr"].id != 10)
            return query

        session = Session()
        user_query = session.query(User).filter_by(id=7)

        expected_sql = (
            "SELECT users.id AS users_id, users.name AS users_name "
            "FROM users "
            "WHERE users.id = :id_1 AND users.id != :id_2"
        )
        self.assert_compile(
            user_query, expected_sql, checkparams={"id_2": 10, "id_1": 7}
        )

    def test_before_compile_no_retval(self):
        """Without ``retval`` the listener runs once per compilation."""
        tally = [0]

        @event.listens_for(query.Query, "before_compile")
        def bump(query):
            tally[0] += 1

        User = self.classes.User
        session = Session()
        user_query = session.query(User).filter_by(id=7)

        # stringify the query twice
        for _ in range(2):
            str(user_query)

        eq_(tally, [2])

    def test_alters_entities(self):
        """A listener can add columns to the query being compiled."""
        User = self.classes.User

        @event.listens_for(query.Query, "before_compile", retval=True)
        def add_name_column(query):
            return query.add_columns(User.name)

        session = Session()
        id_query = session.query(User.id).filter_by(id=7)

        expected_sql = (
            "SELECT users.id AS users_id, users.name AS users_name "
            "FROM users "
            "WHERE users.id = :id_1"
        )
        self.assert_compile(id_query, expected_sql, checkparams={"id_1": 7})
        eq_(id_query.all(), [(7, "jack")])

    def test_before_compile_update(self):
        """``before_compile_update`` can change the query and its values."""
        User = self.classes.User

        @event.listens_for(query.Query, "before_compile_update", retval=True)
        def exclude_and_rename(query, update_context):
            assert update_context.query is query

            for desc in query.column_descriptions:
                if desc["type"] is not User:
                    continue
                query = query.filter(desc["expr"].id != 10)
                update_context.values["name"] += "_modified"
            return query

        session = Session()

        with self.sql_execution_asserter() as asserter:
            session.query(User).filter_by(id=7).update({"name": "ed"})

        expected = CompiledSQL(
            "UPDATE users SET name=:name WHERE "
            "users.id = :id_1 AND users.id != :id_2",
            [{"name": "ed_modified", "id_1": 7, "id_2": 10}],
        )
        asserter.assert_(expected)

    def test_before_compile_delete(self):
        """``before_compile_delete`` can add criteria to a bulk DELETE."""
        User = self.classes.User

        @event.listens_for(query.Query, "before_compile_delete", retval=True)
        def exclude_id_10(query, delete_context):
            assert delete_context.query is query

            for desc in query.column_descriptions:
                if desc["type"] is not User:
                    continue
                query = query.filter(desc["expr"].id != 10)
            return query

        session = Session()

        # note this deletes no rows
        with self.sql_execution_asserter() as asserter:
            session.query(User).filter_by(id=10).delete()

        expected = CompiledSQL(
            "DELETE FROM users WHERE "
            "users.id = :id_1 AND users.id != :id_2",
            [{"id_1": 10, "id_2": 10}],
        )
        asserter.assert_(expected)
2641
2642
class RefreshFlushInReturningTest(fixtures.MappedTest):
    """Test the ``refresh_flush`` event, see [ticket:3427].

    This reworks the test for [ticket:3167] found in test_unitofworkv2,
    which checks that returning doesn't trigger attribute events.  The
    check here is *reversed*: it verifies that the new refresh_flush
    event *is* triggered.

    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        """Define the ``test`` table with two defaulted columns."""
        id_col = Column(
            "id", Integer, primary_key=True, test_needs_autoincrement=True
        )
        prefetch_col = Column("prefetch_val", Integer, default=5)
        returning_col = Column("returning_val", Integer, server_default="5")

        Table("test", metadata, id_col, prefetch_col, returning_col)

    @classmethod
    def setup_classes(cls):
        """Declare the ``Thing`` class used by the mapping."""

        class Thing(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        """Map ``Thing`` to the ``test`` table with ``eager_defaults``."""
        mapper(cls.classes.Thing, cls.tables.test, eager_defaults=True)

    def test_no_attr_events_flush(self):
        """Flushing a new ``Thing`` calls the ``refresh_flush`` listener."""
        Thing = self.classes.Thing
        refresh_listener = Mock()
        event.listen(Thing, "refresh_flush", refresh_listener)

        thing = Thing()
        session = Session()
        session.add(thing)
        session.flush()

        if testing.requires.returning.enabled:
            # ordering is deterministic in this test b.c. the routine
            # appends the "returning" params before the "prefetch"
            # ones.  if there were more than one attribute in each category,
            # then we'd have hash order issues.
            refreshed = ["returning_val", "prefetch_val"]
        else:
            refreshed = ["prefetch_val"]

        eq_(refresh_listener.mock_calls, [call(thing, ANY, refreshed)])

        eq_(thing.id, 1)
        eq_(thing.prefetch_val, 5)
        eq_(thing.returning_val, 5)
2702