1# orm/unitofwork.py
2# Copyright (C) 2005-2018 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""The internals for the unit of work system.
9
10The session's flush() process passes objects to a contextual object
11here, which assembles flush tasks based on mappers and their properties,
12organizes them in order of dependency, and executes.
13
14"""
15
16from .. import util, event
17from ..util import topological
18from . import attributes, persistence, util as orm_util
19from . import exc as orm_exc
20import itertools
21
22
def track_cascade_events(descriptor, prop):
    """Register cascade-handling listeners on a relationship attribute.

    Three listeners are attached to ``descriptor`` for the 'append',
    'remove' and 'set' attribute events.  When the owning object belongs
    to a session they:

    * pull newly related objects into that session, subject to the
      "save_update" cascade rules of the property;
    * expunge pending objects from the session once a removal or
      replacement has left them orphaned, subject to "delete_orphan".

    :param descriptor: the attribute to listen on.
    :param prop: the property whose ``key`` is used to look up the
     effective property on the mapper of whichever state fires the event.

    """
    key = prop.key

    def _cascade_save_update(session, relationship, target_state, initiator):
        # shared by append() and set_(): add the target to the session
        # when save-update cascade applies to this event and the target
        # is not already present.
        if not relationship._cascade.save_update:
            return
        if not (relationship.cascade_backrefs or key == initiator.key):
            return
        if not session._contains_state(target_state):
            session._save_or_update_state(target_state)

    def append(state, item, initiator):
        # "save_update" cascade for an instance being appended to the
        # collection of another instance
        if item is None:
            return

        session = state.session
        if not session:
            return item

        if session._warn_on_events:
            session._flush_warning("collection append")

        relationship = state.manager.mapper._props[key]
        appended_state = attributes.instance_state(item)
        _cascade_save_update(session, relationship, appended_state, initiator)
        return item

    def remove(state, item, initiator):
        if item is None:
            return

        session = state.session
        if not session:
            return

        relationship = state.manager.mapper._props[key]

        if session._warn_on_events:
            if relationship.uselist:
                reason = "collection remove"
            else:
                reason = "related attribute delete"
            session._flush_warning(reason)

        # a pending item that has just become an orphan is expunged
        removed_state = attributes.instance_state(item)
        is_pending_orphan = (
            relationship._cascade.delete_orphan and
            removed_state in session._new and
            relationship.mapper._is_orphan(removed_state)
        )
        if is_pending_orphan:
            session.expunge(item)

    def set_(state, newvalue, oldvalue, initiator):
        # "save_update" cascade for an instance being attached to
        # another instance via a scalar attribute
        if oldvalue is newvalue:
            return newvalue

        session = state.session
        if not session:
            return newvalue

        if session._warn_on_events:
            session._flush_warning("related attribute set")

        relationship = state.manager.mapper._props[key]

        if newvalue is not None:
            incoming_state = attributes.instance_state(newvalue)
            _cascade_save_update(
                session, relationship, incoming_state, initiator)

        # the previous value only counts when it is a real object; the
        # NEVER_SET / PASSIVE_NO_RESULT symbols carry no instance state
        no_previous_object = (
            oldvalue is None or
            oldvalue is attributes.NEVER_SET or
            oldvalue is attributes.PASSIVE_NO_RESULT
        )
        if not no_previous_object and relationship._cascade.delete_orphan:
            replaced_state = attributes.instance_state(oldvalue)
            if replaced_state in session._new and \
                    relationship.mapper._is_orphan(replaced_state):
                session.expunge(oldvalue)
        return newvalue

    listeners = (
        ('append', append),
        ('remove', remove),
        ('set', set_),
    )
    for event_name, listener in listeners:
        event.listen(descriptor, event_name, listener, raw=True, retval=True)
107
108
class UOWTransaction(object):
    """Working state of a single flush.

    Collects the InstanceStates taking part in the flush, the
    Preprocess / PostSortRec action records created for them and the
    dependency edges between those records, then sorts and executes
    the records (see :meth:`execute`).

    """

    def __init__(self, session):
        """Create an empty unit of work for the given ``session``."""
        self.session = session

        # dictionary used by external actors to
        # store arbitrary state information.
        self.attributes = {}

        # dictionary of mappers to sets of
        # DependencyProcessors, which are also
        # set to be part of the sorted flush actions,
        # which have that mapper as a parent.
        self.deps = util.defaultdict(set)

        # dictionary of mappers to sets of InstanceState
        # items pending for flush which have that mapper
        # as a parent.
        self.mappers = util.defaultdict(set)

        # a dictionary of Preprocess objects, which gather
        # additional states impacted by the flush
        # and determine if a flush action is needed
        self.presort_actions = {}

        # dictionary of PostSortRec objects, each
        # one issues work during the flush within
        # a certain ordering.
        self.postsort_actions = {}

        # a set of 2-tuples, each containing two
        # PostSortRec objects where the second
        # is dependent on the first being executed
        # first
        self.dependencies = set()

        # dictionary of InstanceState-> (isdelete, listonly)
        # tuples, indicating if this state is to be deleted
        # or insert/updated, or just refreshed
        self.states = {}

        # tracks InstanceStates which will be receiving
        # a "post update" call.  Keys are mappers,
        # values are a set of states and a set of the
        # columns which should be included in the update.
        self.post_update_states = util.defaultdict(lambda: (set(), set()))

    @property
    def has_work(self):
        """True if at least one state has been registered in
        ``self.states``."""
        return bool(self.states)

    def was_already_deleted(self, state):
        """return true if the given state is expired and was deleted
        previously.

        Only expired states are checked: a load of the expired state is
        attempted, and if it raises ObjectDeletedError the state is
        handed to the session's ``_remove_newly_deleted()`` and True is
        returned.  In every other case the result is False.
        """
        if state.expired:
            try:
                state._load_expired(state, attributes.PASSIVE_OFF)
            except orm_exc.ObjectDeletedError:
                self.session._remove_newly_deleted([state])
                return True
        return False

    def is_deleted(self, state):
        """return true if the given state is marked as deleted
        within this uowtransaction."""

        # states not registered here yield False; otherwise the
        # "isdelete" member of the (isdelete, listonly) tuple is returned
        return state in self.states and self.states[state][0]

    def memo(self, key, callable_):
        """Return the value stored under ``key`` in ``self.attributes``.

        If the key is not present yet, ``callable_()`` is invoked with
        no arguments, and its result is stored under the key and
        returned.
        """
        if key in self.attributes:
            return self.attributes[key]
        else:
            self.attributes[key] = ret = callable_()
            return ret

    def remove_state_actions(self, state):
        """remove pending actions for a state from the uowtransaction.

        The state stays registered; its "isdelete" flag is kept and its
        "listonly" flag is switched to True.  Raises KeyError if the
        state was never registered.
        """

        isdelete = self.states[state][0]

        self.states[state] = (isdelete, True)

    def get_attribute_history(self, state, key,
                              passive=attributes.PASSIVE_NO_INITIALIZE):
        """Return the history of attribute ``key`` on ``state``,
        caching the result for the duration of this unit of work.

        The history is fetched from the attribute implementation with
        LOAD_AGAINST_COMMITTED added to the passive flags.  The
        returned value is ``history.as_state()`` when the history is
        truthy and the implementation ``uses_objects``, otherwise the
        history itself.

        Results are cached in ``self.attributes`` under
        ``("history", state, key)`` as a
        ``(history, state_history, passive)`` tuple.
        """

        hashkey = ("history", state, key)

        # cache the objects, not the states; the strong reference here
        # prevents newly loaded objects from being dereferenced during the
        # flush process

        if hashkey in self.attributes:
            history, state_history, cached_passive = self.attributes[hashkey]
            # if the cached lookup was "passive" and now
            # we want non-passive, do a non-passive lookup and re-cache

            if not cached_passive & attributes.SQL_OK \
                    and passive & attributes.SQL_OK:
                impl = state.manager[key].impl
                # the refreshed lookup always uses PASSIVE_OFF, while
                # the caller's ``passive`` value is what gets recorded
                # in the cache entry
                history = impl.get_history(state, state.dict,
                                           attributes.PASSIVE_OFF |
                                           attributes.LOAD_AGAINST_COMMITTED)
                if history and impl.uses_objects:
                    state_history = history.as_state()
                else:
                    state_history = history
                self.attributes[hashkey] = (history, state_history, passive)
        else:
            impl = state.manager[key].impl
            # TODO: store the history as (state, object) tuples
            # so we don't have to keep converting here
            history = impl.get_history(state, state.dict, passive |
                                       attributes.LOAD_AGAINST_COMMITTED)
            if history and impl.uses_objects:
                state_history = history.as_state()
            else:
                state_history = history
            self.attributes[hashkey] = (history, state_history,
                                        passive)

        return state_history

    def has_dep(self, processor):
        """True if a preprocessor keyed as ``(processor, True)``, i.e.
        with ``fromparent=True``, has been registered."""
        return (processor, True) in self.presort_actions

    def register_preprocessor(self, processor, fromparent):
        """Register a Preprocess for ``(processor, fromparent)`` unless
        one is already present under that key."""
        key = (processor, fromparent)
        if key not in self.presort_actions:
            self.presort_actions[key] = Preprocess(processor, fromparent)

    def register_object(self, state, isdelete=False,
                        listonly=False, cancel_delete=False,
                        operation=None, prop=None):
        """Register ``state`` as a participant in this flush.

        :param state: the InstanceState to register.
        :param isdelete: stored as the "isdelete" flag of the state.
        :param listonly: stored as the "listonly" flag of the state.
        :param cancel_delete: together with ``isdelete``, decides
         whether an already registered state is overwritten (see below).
        :param operation: when not None, used in the warning emitted
         for a state which is not in the session.
        :param prop: used in that same warning message.

        :return: False if the state is not contained in the session, in
         which case nothing is registered; True otherwise.

        A state seen for the first time is stored with the given
        ``(isdelete, listonly)`` flags.  For a state that is already
        registered, the flags are replaced with ``(isdelete, False)``
        only when ``listonly`` is false and either ``isdelete`` or
        ``cancel_delete`` is true; otherwise the existing entry is left
        as it is.
        """
        if not self.session._contains_state(state):
            # this condition is normal when objects are registered
            # as part of a relationship cascade operation.  it should
            # not occur for the top-level register from Session.flush().
            if not state.deleted and operation is not None:
                util.warn("Object of type %s not in session, %s operation "
                          "along '%s' will not proceed" %
                          (orm_util.state_class_str(state), operation, prop))
            return False

        if state not in self.states:
            mapper = state.manager.mapper

            # the first state seen for a mapper sets up the
            # per-mapper flush actions
            if mapper not in self.mappers:
                self._per_mapper_flush_actions(mapper)

            self.mappers[mapper].add(state)
            self.states[state] = (isdelete, listonly)
        else:
            if not listonly and (isdelete or cancel_delete):
                self.states[state] = (isdelete, False)
        return True

    def issue_post_update(self, state, post_update_cols):
        """Record that ``state`` is to receive a "post update" which
        includes the columns in ``post_update_cols``.

        The state and the columns are accumulated per base mapper in
        ``self.post_update_states``.
        """
        mapper = state.manager.mapper.base_mapper
        states, cols = self.post_update_states[mapper]
        states.add(state)
        cols.update(post_update_cols)

    def _per_mapper_flush_actions(self, mapper):
        """Set up the flush actions needed for ``mapper``.

        Creates the SaveUpdateAll and DeleteAll records for the base
        mapper with an edge placing the saves ahead of the deletes,
        then invokes ``per_property_preprocessors()`` on each of the
        mapper's dependency processors and on the dependency processor
        of each of its relationships that is not ``viewonly``.
        """
        saves = SaveUpdateAll(self, mapper.base_mapper)
        deletes = DeleteAll(self, mapper.base_mapper)
        self.dependencies.add((saves, deletes))

        for dep in mapper._dependency_processors:
            dep.per_property_preprocessors(self)

        for prop in mapper.relationships:
            if prop.viewonly:
                continue
            dep = prop._dependency_processor
            dep.per_property_preprocessors(self)

    @util.memoized_property
    def _mapper_for_dep(self):
        """return a dynamic mapping of (Mapper, DependencyProcessor) to
        True or False, indicating if the DependencyProcessor operates
        on objects of that Mapper.

        The result is stored in the dictionary persistently once
        calculated.

        """
        # a processor applies to a mapper when the mapper's property
        # under the processor's key is the processor's own property
        return util.PopulateDict(
            lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop
        )

    def filter_states_for_dep(self, dep, states):
        """Filter the given list of InstanceStates to those relevant to the
        given DependencyProcessor.

        """
        mapper_for_dep = self._mapper_for_dep
        return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]]

    def states_for_mapper_hierarchy(self, mapper, isdelete, listonly):
        """Yield the registered states of the whole hierarchy of
        ``mapper`` whose flags equal ``(isdelete, listonly)``.

        The hierarchy is the base mapper of ``mapper`` together with
        all of its descendants.
        """
        checktup = (isdelete, listonly)
        for mapper in mapper.base_mapper.self_and_descendants:
            for state in self.mappers[mapper]:
                if self.states[state] == checktup:
                    yield state

    def _generate_actions(self):
        """Generate the full, unsorted collection of PostSortRecs as
        well as dependency pairs for this UOWTransaction.

        Also stores the detected cycles on ``self.cycles``.  The
        returned set holds the PostSortRecs which are not disabled and
        not part of a cycle.

        """
        # execute presort_actions, until all states
        # have been processed.   a presort_action might
        # add new states to the uow.
        while True:
            ret = False
            for action in list(self.presort_actions.values()):
                if action.execute(self):
                    ret = True
            if not ret:
                break

        # see if the graph of mapper dependencies has cycles.
        self.cycles = cycles = topological.find_cycles(
            self.dependencies,
            list(self.postsort_actions.values()))

        if cycles:
            # if yes, break the per-mapper actions into
            # per-state actions
            convert = dict(
                (rec, set(rec.per_state_flush_actions(self)))
                for rec in cycles
            )

            # rewrite the existing dependencies to point to
            # the per-state actions for those per-mapper actions
            # that were broken up.
            for edge in list(self.dependencies):
                # edges with a None member, with a disabled member, or
                # lying entirely within the cycles are dropped outright
                if None in edge or \
                        edge[0].disabled or edge[1].disabled or \
                        cycles.issuperset(edge):
                    self.dependencies.remove(edge)
                elif edge[0] in cycles:
                    self.dependencies.remove(edge)
                    for dep in convert[edge[0]]:
                        self.dependencies.add((dep, edge[1]))
                elif edge[1] in cycles:
                    self.dependencies.remove(edge)
                    for dep in convert[edge[1]]:
                        self.dependencies.add((edge[0], dep))

        return set([a for a in self.postsort_actions.values()
                    if not a.disabled
                    ]
                   ).difference(cycles)

    def execute(self):
        """Generate, sort and run all flush actions.

        Without cycles, the records are executed one at a time in
        topological order.  With cycles, the records are sorted into
        subsets; each subset is drained by popping a record and calling
        its ``execute_aggregate()`` with the remainder of the subset.
        """
        postsort_actions = self._generate_actions()

        # sort = topological.sort(self.dependencies, postsort_actions)
        # print "--------------"
        # print "\ndependencies:", self.dependencies
        # print "\ncycles:", self.cycles
        # print "\nsort:", list(sort)
        # print "\nCOUNT OF POSTSORT ACTIONS", len(postsort_actions)

        # execute
        if self.cycles:
            for set_ in topological.sort_as_subsets(
                    self.dependencies,
                    postsort_actions):
                # execute_aggregate() is handed the live set and may
                # itself take further records out of it
                while set_:
                    n = set_.pop()
                    n.execute_aggregate(self, set_)
        else:
            for rec in topological.sort(
                    self.dependencies,
                    postsort_actions):
                rec.execute(self)

    def finalize_flush_changes(self):
        """mark processed objects as clean / deleted after a successful
        flush().

        this method is called within the flush() method after the
        execute() method has succeeded and the transaction has been committed.

        """
        if not self.states:
            return

        # split the registered states on their "isdelete" flag; the
        # "listonly" flag is not considered here
        states = set(self.states)
        isdel = set(
            s for (s, (isdelete, listonly)) in self.states.items()
            if isdelete
        )
        other = states.difference(isdel)
        if isdel:
            self.session._remove_newly_deleted(isdel)
        if other:
            self.session._register_newly_persistent(other)
412
413
class IterateMappersMixin(object):
    """Mixin for action records that carry ``dependency_processor`` and
    ``fromparent`` attributes and need the mappers those apply to."""

    def _mappers(self, uow):
        """Return an iterable of the mappers this record operates on.

        When ``fromparent`` is false, this is the ``self_and_descendants``
        collection of the processor's target mapper, returned as is.
        Otherwise it is a lazy iterator over the ``self_and_descendants``
        of the processor's parent mapper, limited to those mappers for
        which ``uow._mapper_for_dep`` reports that the processor applies.
        """
        processor = self.dependency_processor
        if not self.fromparent:
            return processor.mapper.self_and_descendants

        return (
            candidate
            for candidate in processor.parent.self_and_descendants
            if uow._mapper_for_dep[(candidate, processor)]
        )
424
425
class Preprocess(IterateMappersMixin):
    """Presort action for one (DependencyProcessor, fromparent) pair.

    Each call to :meth:`execute` hands the not yet processed states of
    the relevant mappers to the processor's presort hooks, and sets up
    the per-property flush actions the first time the processor reports
    changes.
    """

    def __init__(self, dependency_processor, fromparent):
        self.dependency_processor = dependency_processor
        self.fromparent = fromparent
        # states already handed to the presort hooks
        self.processed = set()
        # becomes True once per_property_flush_actions() has been called
        self.setup_flush_actions = False

    def execute(self, uow):
        """Process the states not seen by a previous call.

        States flagged "listonly" are skipped.  Returns True if any
        state was processed during this call, False otherwise.
        """
        to_delete = set()
        to_save = set()

        for mapper in self._mappers(uow):
            unseen = uow.mappers[mapper].difference(self.processed)
            for state in unseen:
                isdelete, listonly = uow.states[state]
                if listonly:
                    continue
                if isdelete:
                    to_delete.add(state)
                else:
                    to_save.add(state)

        processor = self.dependency_processor

        if to_delete:
            processor.presort_deletes(uow, to_delete)
            self.processed.update(to_delete)
        if to_save:
            processor.presort_saves(uow, to_save)
            self.processed.update(to_save)

        if not (to_delete or to_save):
            return False

        # the flush actions are established at most once per Preprocess
        if not self.setup_flush_actions:
            if processor.prop_has_changes(uow, to_delete, True) or \
                    processor.prop_has_changes(uow, to_save, False):
                processor.per_property_flush_actions(uow)
                self.setup_flush_actions = True
        return True
465
466
class PostSortRec(object):
    """Base for the action records executed by the unit of work after
    sorting.

    Records are unique per unit of work: constructing one with a class
    and arguments that were used before returns the record already
    stored in ``uow.postsort_actions``.
    """

    disabled = False

    def __new__(cls, uow, *args):
        # the class plus the constructor arguments after ``uow``
        # identify the record
        key = (cls, ) + args
        registry = uow.postsort_actions
        if key not in registry:
            registry[key] = object.__new__(cls)
        return registry[key]

    def execute_aggregate(self, uow, recs):
        """Execute this record; ``recs`` is left untouched here."""
        self.execute(uow)

    def __repr__(self):
        members = ",".join(str(value) for value in self.__dict__.values())
        return "%s(%s)" % (self.__class__.__name__, members)
488
489
class ProcessAll(IterateMappersMixin, PostSortRec):
    """Run a DependencyProcessor's save or delete processing across all
    applicable states in one step."""

    def __init__(self, uow, dependency_processor, delete, fromparent):
        self.dependency_processor = dependency_processor
        self.delete = delete
        self.fromparent = fromparent
        # index the processor under the base mapper of its parent
        parent_base = dependency_processor.parent.base_mapper
        uow.deps[parent_base].add(dependency_processor)

    def execute(self, uow):
        """Pass the matching states to ``process_deletes()`` or
        ``process_saves()`` depending on ``self.delete``."""
        processor = self.dependency_processor
        if self.delete:
            handler = processor.process_deletes
        else:
            handler = processor.process_saves
        handler(uow, self._elements(uow))

    def per_state_flush_actions(self, uow):
        # nothing is produced here: SaveUpdateAll and DeleteAll take
        # care of this, since a ProcessAll should unconditionally be
        # pulled into per-state if either the parent/child mappers
        # are part of a cycle
        return iter(())

    def __repr__(self):
        description = (
            self.__class__.__name__,
            self.dependency_processor,
            self.delete,
        )
        return "%s(%s, delete=%s)" % description

    def _elements(self, uow):
        """Yield the states of the relevant mappers which are not
        "listonly" and whose "isdelete" flag equals ``self.delete``."""
        for mapper in self._mappers(uow):
            for state in uow.mappers[mapper]:
                isdelete, listonly = uow.states[state]
                if isdelete != self.delete or listonly:
                    continue
                yield state
525
526
class IssuePostUpdate(PostSortRec):
    """Emit the "post update" for the states collected under a mapper,
    restricted to either the deleted or the non-deleted ones."""

    def __init__(self, uow, mapper, isdelete):
        self.mapper = mapper
        self.isdelete = isdelete

    def execute(self, uow):
        """Run ``persistence.post_update()`` for the recorded states
        whose "isdelete" flag equals ``self.isdelete``."""
        recorded, cols = uow.post_update_states[self.mapper]

        matching = []
        for state in recorded:
            if uow.states[state][0] == self.isdelete:
                matching.append(state)

        persistence.post_update(self.mapper, matching, uow, cols)
537
538
class SaveUpdateAll(PostSortRec):
    """Save (insert/update) all pending states of a base mapper's
    hierarchy in one step."""

    def __init__(self, uow, mapper):
        self.mapper = mapper
        assert mapper is mapper.base_mapper

    def execute(self, uow):
        """Pass every state flagged ``(isdelete=False, listonly=False)``
        in the hierarchy to ``persistence.save_obj()``."""
        pending = uow.states_for_mapper_hierarchy(self.mapper, False, False)
        persistence.save_obj(self.mapper, pending, uow)

    def per_state_flush_actions(self, uow):
        """Break this record up into one SaveUpdateState per state.

        This is a generator.  Each yielded action has been given an
        edge that places it ahead of the DeleteAll record of the same
        mapper.  Once the actions are exhausted, each dependency
        processor stored under this mapper is asked to set up its own
        per-state actions for the states it applies to.
        """
        mapper = self.mapper
        pending = list(
            uow.states_for_mapper_hierarchy(mapper, False, False))
        base_mapper = mapper.base_mapper
        deletes = DeleteAll(uow, base_mapper)

        for state in pending:
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            per_state = SaveUpdateState(uow, state, base_mapper)
            uow.dependencies.add((per_state, deletes))
            yield per_state

        for processor in uow.deps[mapper]:
            relevant = uow.filter_states_for_dep(processor, pending)
            processor.per_state_flush_actions(uow, relevant, False)
566
567
class DeleteAll(PostSortRec):
    """Delete all states marked for deletion in a base mapper's
    hierarchy in one step."""

    def __init__(self, uow, mapper):
        self.mapper = mapper
        assert mapper is mapper.base_mapper

    def execute(self, uow):
        """Pass every state flagged ``(isdelete=True, listonly=False)``
        in the hierarchy to ``persistence.delete_obj()``."""
        doomed = uow.states_for_mapper_hierarchy(self.mapper, True, False)
        persistence.delete_obj(self.mapper, doomed, uow)

    def per_state_flush_actions(self, uow):
        """Break this record up into one DeleteState per state.

        This is a generator.  Each yielded action has been given an
        edge that places the SaveUpdateAll record of the same mapper
        ahead of it.  Once the actions are exhausted, each dependency
        processor stored under this mapper is asked to set up its own
        per-state actions for the states it applies to.
        """
        mapper = self.mapper
        doomed = list(
            uow.states_for_mapper_hierarchy(mapper, True, False))
        base_mapper = mapper.base_mapper
        saves = SaveUpdateAll(uow, base_mapper)

        for state in doomed:
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            per_state = DeleteState(uow, state, base_mapper)
            uow.dependencies.add((saves, per_state))
            yield per_state

        for processor in uow.deps[mapper]:
            relevant = uow.filter_states_for_dep(processor, doomed)
            processor.per_state_flush_actions(uow, relevant, True)
595
596
class ProcessState(PostSortRec):
    """Run a DependencyProcessor's save or delete processing for one
    individual state."""

    def __init__(self, uow, dependency_processor, delete, state):
        self.dependency_processor = dependency_processor
        self.delete = delete
        self.state = state

    def execute_aggregate(self, uow, recs):
        """Process this state along with compatible records in ``recs``.

        Records of the same class that share this record's dependency
        processor and ``delete`` flag are taken out of ``recs``, and
        their states are handed to the processor in the same call as
        this record's own state.
        """
        processor = self.dependency_processor
        is_delete = self.delete
        own_class = self.__class__

        siblings = []
        for rec in recs:
            # the class test comes first; the attributes read after it
            # are the ones this class defines
            if rec.__class__ is own_class and \
                    rec.dependency_processor is processor and \
                    rec.delete is is_delete:
                siblings.append(rec)
        recs.difference_update(siblings)

        batch = [self.state]
        batch.extend(rec.state for rec in siblings)

        if is_delete:
            processor.process_deletes(uow, batch)
        else:
            processor.process_saves(uow, batch)

    def __repr__(self):
        description = (
            self.__class__.__name__,
            self.dependency_processor,
            orm_util.state_str(self.state),
            self.delete,
        )
        return "%s(%s, %s, delete=%s)" % description
625
626
class SaveUpdateState(PostSortRec):
    """Save (insert/update) one individual state."""

    def __init__(self, uow, state, mapper):
        self.state = state
        self.mapper = mapper

    def execute_aggregate(self, uow, recs):
        """Save this state along with compatible records in ``recs``.

        Records of the same class that share this record's mapper are
        taken out of ``recs``, and their states are passed to a single
        ``persistence.save_obj()`` call with this record's own state
        first.
        """
        mapper = self.mapper
        own_class = self.__class__

        siblings = []
        for rec in recs:
            if rec.__class__ is own_class and rec.mapper is mapper:
                siblings.append(rec)
        recs.difference_update(siblings)

        batch = [self.state]
        batch.extend(rec.state for rec in siblings)
        persistence.save_obj(mapper, batch, uow)

    def __repr__(self):
        label = orm_util.state_str(self.state)
        return "%s(%s)" % (self.__class__.__name__, label)
649
650
class DeleteState(PostSortRec):
    """Delete one individual state."""

    def __init__(self, uow, state, mapper):
        self.state = state
        self.mapper = mapper

    def execute_aggregate(self, uow, recs):
        """Delete this state along with compatible records in ``recs``.

        Records of the same class that share this record's mapper are
        taken out of ``recs``.  Of the gathered states, only those
        whose "isdelete" flag in ``uow.states`` is true are passed to
        ``persistence.delete_obj()``.
        """
        mapper = self.mapper
        own_class = self.__class__

        siblings = []
        for rec in recs:
            if rec.__class__ is own_class and rec.mapper is mapper:
                siblings.append(rec)
        recs.difference_update(siblings)

        gathered = [self.state]
        gathered.extend(rec.state for rec in siblings)

        # consult the uow for the current flag of each state
        doomed = [state for state in gathered if uow.states[state][0]]
        persistence.delete_obj(mapper, doomed, uow)

    def __repr__(self):
        label = orm_util.state_str(self.state)
        return "%s(%s)" % (self.__class__.__name__, label)
673