1from collections import namedtuple, defaultdict
2from contextlib import contextmanager
3from functools import wraps
4from threading import local
5from django.apps import apps
6from django.core import serializers
7from django.core.exceptions import ObjectDoesNotExist
8from django.db import models, transaction, router
9from django.db.models.query import QuerySet
10from django.db.models.signals import post_save, m2m_changed
11from django.utils.encoding import force_str
12from django.utils import timezone
13from reversion.errors import RevisionManagementError, RegistrationError
14from reversion.signals import pre_revision_commit, post_revision_commit
15
16
17_VersionOptions = namedtuple("VersionOptions", (
18    "fields",
19    "follow",
20    "format",
21    "for_concrete_model",
22    "ignore_duplicates",
23))
24
25
26_StackFrame = namedtuple("StackFrame", (
27    "manage_manually",
28    "user",
29    "comment",
30    "date_created",
31    "db_versions",
32    "meta",
33))
34
35
36class _Local(local):
37
38    def __init__(self):
39        self.stack = ()
40
41
42_local = _Local()
43
44
def is_active():
    """Return True when the current thread has at least one revision frame."""
    return len(_local.stack) > 0
47
48
def _current_frame():
    """Return the innermost revision frame of the current thread.

    Raises RevisionManagementError when no revision is active.
    """
    if is_active():
        return _local.stack[-1]
    raise RevisionManagementError("There is no active revision for this thread")
53
54
55def _copy_db_versions(db_versions):
56    return {
57        db: versions.copy()
58        for db, versions
59        in db_versions.items()
60    }
61
62
def _push_frame(manage_manually, using):
    """Push a new frame onto this thread's revision stack.

    The first frame starts from blank defaults and tracks only ``using``.
    A nested frame inherits everything from its parent except
    ``manage_manually``, works on a copy of the parent's version maps, and
    starts tracking ``using`` if the parent did not.
    """
    if not is_active():
        new_frame = _StackFrame(
            manage_manually=manage_manually,
            user=None,
            comment="",
            date_created=timezone.now(),
            db_versions={using: {}},
            meta=(),
        )
    else:
        parent_frame = _current_frame()
        # Copy so that changes in the nested frame don't touch the parent's maps.
        inherited_versions = _copy_db_versions(parent_frame.db_versions)
        if using not in inherited_versions:
            inherited_versions[using] = {}
        new_frame = parent_frame._replace(
            manage_manually=manage_manually,
            db_versions=inherited_versions,
        )
    _local.stack = _local.stack + (new_frame,)
82
83
def _update_frame(**kwargs):
    """Swap the innermost frame for a copy with the given fields replaced.

    Raises RevisionManagementError when no revision is active.
    """
    outer_frames = _local.stack[:-1]
    updated_frame = _current_frame()._replace(**kwargs)
    _local.stack = outer_frames + (updated_frame,)
86
87
def _pop_frame():
    """Pop the innermost frame and fold its state into the frame beneath it.

    When an enclosing frame remains, it takes over the popped frame's user,
    comment, date_created and meta, plus the popped frame's versions for the
    databases the enclosing frame already tracks.

    Raises RevisionManagementError when no revision is active.
    """
    popped_frame = _current_frame()
    _local.stack = _local.stack[:-1]
    if not is_active():
        return
    parent_frame = _current_frame()
    # Databases that only the popped frame tracked are not carried over.
    carried_versions = {}
    for db in parent_frame.db_versions.keys():
        carried_versions[db] = popped_frame.db_versions[db]
    _update_frame(
        user=popped_frame.user,
        comment=popped_frame.comment,
        date_created=popped_frame.date_created,
        db_versions=carried_versions,
        meta=popped_frame.meta,
    )
105
106
def is_manage_manually():
    """Return the ``manage_manually`` flag of the active revision frame.

    Raises RevisionManagementError when no revision is active.
    """
    frame = _current_frame()
    return frame.manage_manually
109
110
def set_user(user):
    """Set the user recorded on the active revision frame.

    Raises RevisionManagementError when no revision is active.
    """
    _update_frame(user=user)
113
114
def get_user():
    """Return the user recorded on the active revision frame.

    Raises RevisionManagementError when no revision is active.
    """
    frame = _current_frame()
    return frame.user
117
118
def set_comment(comment):
    """Set the comment recorded on the active revision frame.

    Raises RevisionManagementError when no revision is active.
    """
    _update_frame(comment=comment)
121
122
def get_comment():
    """Return the comment recorded on the active revision frame.

    Raises RevisionManagementError when no revision is active.
    """
    frame = _current_frame()
    return frame.comment
125
126
def set_date_created(date_created):
    """Set the creation date recorded on the active revision frame.

    Raises RevisionManagementError when no revision is active.
    """
    _update_frame(date_created=date_created)
129
130
def get_date_created():
    """Return the creation date recorded on the active revision frame.

    Raises RevisionManagementError when no revision is active.
    """
    frame = _current_frame()
    return frame.date_created
133
134
def add_meta(model, **values):
    """Queue a meta model instance to be created with the active revision.

    The ``(model, values)`` pair is appended to the frame's ``meta`` tuple;
    when the revision is saved, each pair is created with ``values`` as
    field values plus the revision itself.

    Raises RevisionManagementError when no revision is active.
    """
    frame = _current_frame()
    meta_entry = (model, values)
    _update_frame(meta=frame.meta + (meta_entry,))
137
138
def _follow_relations(obj):
    """Yield the model instances reachable through obj's ``follow`` attributes.

    Each name in the registered ``follow`` option is read from ``obj``:
    a model instance is yielded as-is, a manager or queryset yields every
    instance from ``.all()``, ``None`` and relations raising
    ObjectDoesNotExist are skipped.

    Raises RegistrationError if obj's class is not registered, or if a
    followed attribute holds anything else.
    """
    options = _get_options(obj.__class__)
    for relation_name in options.follow:
        try:
            related = getattr(obj, relation_name)
        except ObjectDoesNotExist:
            # The related object is missing, so there is nothing to follow.
            continue
        if isinstance(related, models.Model):
            yield related
        elif isinstance(related, (models.Manager, QuerySet)):
            yield from related.all()
        elif related is not None:
            raise RegistrationError("{name}.{follow_name} should be a Model or QuerySet".format(
                name=obj.__class__.__name__,
                follow_name=relation_name,
            ))
156
157
def _follow_relations_recursive(obj):
    """Return the set containing obj and everything reachable via follow relations.

    The visited set doubles as the cycle guard, so each object is expanded
    at most once.
    """
    visited = set()

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for related in _follow_relations(node):
            visit(related)

    visit(obj)
    return visited
167
168
def _add_to_revision(obj, using, model_db, explicit):
    """Serialize obj into a Version and store it in the active frame for ``using``.

    Args:
        obj: The model instance to record; its class must be registered.
        using: Database alias whose version map in the active frame is
            updated. Also used for the content type lookup and the
            duplicate check.
        model_db: Database alias stored on the Version as ``db``.
        explicit: When false, an object already present in the frame is left
            untouched. The duplicate check runs only for explicit adds.

    Followed relations are added recursively with ``explicit=False``.
    """
    from reversion.models import Version
    # Exit early if the object is not fully-formed.
    if obj.pk is None:
        return
    version_options = _get_options(obj.__class__)
    content_type = _get_content_type(obj.__class__, using)
    object_id = force_str(obj.pk)
    version_key = (content_type, object_id)
    # If the obj is already in the revision, stop now.
    # An explicit add falls through and overwrites the stored version below.
    db_versions = _current_frame().db_versions
    versions = db_versions[using]
    if version_key in versions and not explicit:
        return
    # Get the version data.
    version = Version(
        content_type=content_type,
        object_id=object_id,
        db=model_db,
        format=version_options.format,
        serialized_data=serializers.serialize(
            version_options.format,
            (obj,),
            fields=version_options.fields,
        ),
        object_repr=force_str(obj),
    )
    # If the version is a duplicate, stop now.
    # Compares against the first version get_for_object() returns for obj.
    if version_options.ignore_duplicates and explicit:
        previous_version = Version.objects.using(using).get_for_object(obj, model_db=model_db).first()
        if previous_version and previous_version._local_field_dict == version._local_field_dict:
            return
    # Store the version.
    # Work on a copy and swap it into the frame; the old mapping is not mutated.
    db_versions = _copy_db_versions(db_versions)
    db_versions[using][version_key] = version
    _update_frame(db_versions=db_versions)
    # Follow relations.
    # Storing this object first means a cycle back to it hits the early return above.
    for follow_obj in _follow_relations(obj):
        _add_to_revision(follow_obj, using, model_db, False)
208
209
def add_to_revision(obj, model_db=None):
    """Explicitly add obj to the active revision for every tracked database.

    Args:
        obj: A saved instance of a registered model.
        model_db: Database alias the object lives in. When falsy, the
            router's write database for the object is used.

    Raises RevisionManagementError when no revision is active.
    """
    if not model_db:
        model_db = router.db_for_write(obj.__class__, instance=obj)
    tracked_dbs = _current_frame().db_versions
    for db in tracked_dbs.keys():
        _add_to_revision(obj, db, model_db, True)
214
215
def _save_revision(versions, user=None, comment="", meta=(), date_created=None, using=None):
    """Persist the given versions as one Revision in database ``using``.

    Versions whose object no longer exists in its model's database are
    dropped first; if none remain, nothing is saved and no signal is sent.

    Args:
        versions: Iterable of unsaved Version instances. It is iterated more
            than once, so it must not be a one-shot iterator.
        user: Passed to the Revision.
        comment: Passed to the Revision.
        meta: Iterable of ``(model, field_values)`` pairs; one instance of
            each model is created, linked to the revision.
        date_created: Passed to the Revision.
        using: Database alias the revision, versions and meta are saved to.

    pre_revision_commit is sent before the revision is saved and
    post_revision_commit after everything is saved, both with
    ``create_revision`` as sender.
    """
    from reversion.models import Revision
    # Only save versions that exist in the database.
    # Use _base_manager so we don't have problems when _default_manager is overridden
    # Group object ids by model and database so each pair needs one query.
    # NOTE(review): version._model is defined on Version elsewhere; presumably the model class.
    model_db_pks = defaultdict(lambda: defaultdict(set))
    for version in versions:
        model_db_pks[version._model][version.db].add(version.object_id)
    # Primary keys are converted with force_str so they compare equal to
    # the string object_id stored on each version.
    model_db_existing_pks = {
        model: {
            db: frozenset(map(
                force_str,
                model._base_manager.using(db).filter(pk__in=pks).values_list("pk", flat=True),
            ))
            for db, pks in db_pks.items()
        }
        for model, db_pks in model_db_pks.items()
    }
    versions = [
        version for version in versions
        if version.object_id in model_db_existing_pks[version._model][version.db]
    ]
    # Bail early if there are no objects to save.
    if not versions:
        return
    # Save a new revision.
    revision = Revision(
        date_created=date_created,
        user=user,
        comment=comment,
    )
    # Send the pre_revision_commit signal.
    # At this point the revision has not been saved yet.
    pre_revision_commit.send(
        sender=create_revision,
        revision=revision,
        versions=versions,
    )
    # Save the revision.
    revision.save(using=using)
    # Save version models.
    for version in versions:
        version.revision = revision
        version.save(using=using)
    # Save the meta information.
    for meta_model, meta_fields in meta:
        meta_model._base_manager.db_manager(using=using).create(
            revision=revision,
            **meta_fields
        )
    # Send the post_revision_commit signal.
    post_revision_commit.send(
        sender=create_revision,
        revision=revision,
        versions=versions,
    )
270
271
272@contextmanager
273def _dummy_context():
274    yield
275
276
@contextmanager
def _create_revision_context(manage_manually, using, atomic):
    """Run the managed block inside a revision frame for database ``using``.

    A frame is pushed on entry and always popped on exit. If the block
    finishes without raising and no enclosing frame tracks ``using``, the
    versions collected for ``using`` are saved as a revision. When
    ``atomic`` is true, all of this happens inside a transaction on
    ``using``.
    """
    if atomic:
        outer_context = transaction.atomic(using=using)
    else:
        outer_context = _dummy_context()
    with outer_context:
        _push_frame(manage_manually, using)
        try:
            yield
            # Only the outermost frame tracking this database saves; nested
            # frames hand their versions to the parent when they are popped.
            enclosing_frames = _local.stack[:-1]
            tracked_by_parent = any(using in frame.db_versions for frame in enclosing_frames)
            if not tracked_by_parent:
                frame = _current_frame()
                _save_revision(
                    versions=frame.db_versions[using].values(),
                    user=frame.user,
                    comment=frame.comment,
                    meta=frame.meta,
                    date_created=frame.date_created,
                    using=using,
                )
        finally:
            _pop_frame()
297
298
def create_revision(manage_manually=False, using=None, atomic=True):
    """Return an object that opens a revision, as a ``with`` target or a decorator.

    Args:
        manage_manually: Stored on the revision frame; when true, the signal
            receivers in this module do not add saved objects automatically.
        using: Database alias for the revision. When falsy, the router's
            write database for Revision is used.
        atomic: When true, the revision block runs inside a transaction.
    """
    from reversion.models import Revision
    if not using:
        using = router.db_for_write(Revision)
    context_args = (manage_manually, using, atomic)
    return _ContextWrapper(_create_revision_context, context_args)
303
304
305class _ContextWrapper(object):
306
307    def __init__(self, func, args):
308        self._func = func
309        self._args = args
310        self._context = func(*args)
311
312    def __enter__(self):
313        return self._context.__enter__()
314
315    def __exit__(self, exc_type, exc_value, traceback):
316        return self._context.__exit__(exc_type, exc_value, traceback)
317
318    def __call__(self, func):
319        @wraps(func)
320        def do_revision_context(*args, **kwargs):
321            with self._func(*self._args):
322                return func(*args, **kwargs)
323        return do_revision_context
324
325
def _post_save_receiver(sender, instance, using, **kwargs):
    """post_save handler: add the saved instance to the active revision.

    Nothing happens unless the sender is registered, a revision is active,
    and that revision is not managed manually.
    """
    if not is_registered(sender):
        return
    if not is_active():
        return
    if is_manage_manually():
        return
    add_to_revision(instance, model_db=using)
329
330
def _m2m_changed_receiver(instance, using, action, model, reverse, **kwargs):
    """m2m_changed handler: add the changed instance to the active revision.

    Only ``post_*`` actions on the forward side of the relation are handled,
    and only when the instance's model is registered, a revision is active,
    and that revision is not managed manually.
    """
    if not action.startswith("post_") or reverse:
        return
    if not is_registered(instance):
        return
    if not is_active():
        return
    if is_manage_manually():
        return
    add_to_revision(instance, model_db=using)
335
336
337def _get_registration_key(model):
338    return (model._meta.app_label, model._meta.model_name)
339
340
# Registry mapping (app_label, model_name) keys to the options each model was registered with.
_registered_models = {}
342
343
def is_registered(model):
    """Return True if model (a class or an instance) has been registered."""
    registration_key = _get_registration_key(model)
    return registration_key in _registered_models
346
347
def get_registered_models():
    """Return a generator over the model classes of all registered models.

    Each registry key is resolved to its model class through the app registry.
    """
    return (
        apps.get_model(app_label, model_name)
        for app_label, model_name in _registered_models
    )
350
351
def _get_senders_and_signals(model):
    """Yield ``(sender, signal, receiver)`` triples to connect for model.

    The model itself is paired with post_save, and the ``through`` model of
    every local many-to-many field on the concrete model with m2m_changed.
    """
    yield model, post_save, _post_save_receiver
    concrete_opts = model._meta.concrete_model._meta
    for m2m_field in concrete_opts.local_many_to_many:
        through = m2m_field.remote_field.through
        # A string reference without an app label is qualified with the
        # concrete model's own app label.
        if isinstance(through, str) and "." not in through:
            through = "{app_label}.{m2m_model}".format(
                app_label=concrete_opts.app_label,
                m2m_model=through,
            )
        yield through, m2m_changed, _m2m_changed_receiver
364
365
def register(model=None, fields=None, exclude=(), follow=(), format="json",
             for_concrete_model=True, ignore_duplicates=False):
    """Register a model for version tracking, directly or as a class decorator.

    Args:
        model: The model class to register. When omitted, a decorator that
            performs the registration is returned instead.
        fields: Names of the fields to serialize. Defaults to every local
            field and local many-to-many field of the concrete model.
        exclude: Field names removed from ``fields``.
        follow: Attribute names of relations to save alongside the object.
        format: Serialization format name.
        for_concrete_model: Forwarded to the ContentType lookup.
        ignore_duplicates: Skip explicitly added versions that equal the
            previous version of the object.

    Returns:
        The registered model, or the decorator when ``model`` is None.

    Raises:
        RegistrationError: If the model is already registered.
    """
    def register(model):
        # Refuse to register the same model twice.
        if is_registered(model):
            raise RegistrationError("{model} has already been registered with django-reversion".format(
                model=model,
            ))
        # Work out which fields get serialized.
        concrete_opts = model._meta.concrete_model._meta
        if fields is None:
            candidate_names = [
                field.name
                for field
                in concrete_opts.local_fields + concrete_opts.local_many_to_many
            ]
        else:
            candidate_names = fields
        options = _VersionOptions(
            fields=tuple(name for name in candidate_names if name not in exclude),
            follow=tuple(follow),
            format=format,
            for_concrete_model=for_concrete_model,
            ignore_duplicates=ignore_duplicates,
        )
        # Record the options, then hook up the signal receivers.
        _registered_models[_get_registration_key(model)] = options
        for sender, signal, signal_receiver in _get_senders_and_signals(model):
            signal.connect(signal_receiver, sender=sender)
        return model
    # Without a model, hand back the decorator; otherwise register right away.
    return register if model is None else register(model)
404
405
def _assert_registered(model):
    """Raise RegistrationError unless model has been registered."""
    if is_registered(model):
        return
    raise RegistrationError("{model} has not been registered with django-reversion".format(
        model=model,
    ))
411
412
def _get_options(model):
    """Return the options model was registered with.

    Raises RegistrationError if model has not been registered.
    """
    _assert_registered(model)
    registration_key = _get_registration_key(model)
    return _registered_models[registration_key]
416
417
def unregister(model):
    """Remove model from the registry and disconnect its signal receivers.

    Raises RegistrationError if model has not been registered.
    """
    _assert_registered(model)
    registration_key = _get_registration_key(model)
    del _registered_models[registration_key]
    # Undo the connections made when the model was registered.
    for sender, signal, signal_receiver in _get_senders_and_signals(model):
        signal.disconnect(signal_receiver, sender=sender)
424
425
def _get_content_type(model, using):
    """Return the ContentType for model, looked up through database ``using``.

    The model's registered ``for_concrete_model`` option is forwarded to the
    lookup. Raises RegistrationError if model has not been registered.
    """
    from django.contrib.contenttypes.models import ContentType
    options = _get_options(model)
    content_types = ContentType.objects.db_manager(using)
    return content_types.get_for_model(
        model,
        for_concrete_model=options.for_concrete_model,
    )
433