from collections import namedtuple, defaultdict
from contextlib import contextmanager
from functools import wraps
from threading import local
from django.apps import apps
from django.core import serializers
from django.core.exceptions import ObjectDoesNotExist
from django.db import models, transaction, router
from django.db.models.query import QuerySet
from django.db.models.signals import post_save, m2m_changed
from django.utils.encoding import force_str
from django.utils import timezone
from reversion.errors import RevisionManagementError, RegistrationError
from reversion.signals import pre_revision_commit, post_revision_commit


# Per-model registration options; one instance is stored per registered model.
_VersionOptions = namedtuple("VersionOptions", (
    "fields",
    "follow",
    "format",
    "for_concrete_model",
    "ignore_duplicates",
))


# One entry of the per-thread revision stack. Frames are never mutated in
# place; they are swapped for modified copies via ``_replace``.
_StackFrame = namedtuple("StackFrame", (
    "manage_manually",
    "user",
    "comment",
    "date_created",
    "db_versions",  # {db alias: {(content_type, object_id): Version}}
    "meta",  # tuple of (meta_model, field_values) pairs
))


class _Local(local):
    """Thread-local holder for the revision stack."""

    def __init__(self):
        # Tuple of _StackFrame, innermost (current) frame last.
        self.stack = ()


_local = _Local()


def is_active():
    """Return True if a revision block is active in the current thread."""
    return bool(_local.stack)


def _current_frame():
    """Return the innermost stack frame.

    Raises RevisionManagementError if no revision is active in this thread.
    """
    if not is_active():
        raise RevisionManagementError("There is no active revision for this thread")
    return _local.stack[-1]


def _copy_db_versions(db_versions):
    """Copy ``db_versions`` together with each per-database dict (the Version objects are shared)."""
    return {
        db: versions.copy()
        for db, versions
        in db_versions.items()
    }


def _push_frame(manage_manually, using):
    """Push a new frame that collects versions for database ``using``.

    A nested frame inherits user/comment/date/meta and the already-collected
    versions from the enclosing frame. A top-level frame starts empty, with
    ``date_created`` set to now.
    """
    if is_active():
        current_frame = _current_frame()
        # Copy so the nested frame can collect versions without mutating its parent.
        db_versions = _copy_db_versions(current_frame.db_versions)
        db_versions.setdefault(using, {})
        stack_frame = current_frame._replace(
            manage_manually=manage_manually,
            db_versions=db_versions,
        )
    else:
        stack_frame = _StackFrame(
            manage_manually=manage_manually,
            user=None,
            comment="",
            date_created=timezone.now(),
            db_versions={using: {}},
            meta=(),
        )
    _local.stack += (stack_frame,)


def _update_frame(**kwargs):
    """Replace the innermost frame with a copy whose ``kwargs`` fields are updated."""
    _local.stack = _local.stack[:-1] + (_current_frame()._replace(**kwargs),)


def _pop_frame():
    """Pop the innermost frame and copy its state back into the enclosing frame, if any."""
    prev_frame = _current_frame()
    _local.stack = _local.stack[:-1]
    if is_active():
        current_frame = _current_frame()
        # Only databases the outer frame already tracks are carried back.
        # Databases first introduced by the popped frame are dropped here.
        db_versions = {
            db: prev_frame.db_versions[db]
            for db
            in current_frame.db_versions.keys()
        }
        # manage_manually is intentionally not propagated; the outer frame keeps its own.
        _update_frame(
            user=prev_frame.user,
            comment=prev_frame.comment,
            date_created=prev_frame.date_created,
            db_versions=db_versions,
            meta=prev_frame.meta,
        )


def is_manage_manually():
    """Return the ``manage_manually`` flag of the active revision."""
    return _current_frame().manage_manually


def set_user(user):
    """Set the user to be stored on the active revision."""
    _update_frame(user=user)


def get_user():
    """Return the user set on the active revision (None by default)."""
    return _current_frame().user


def set_comment(comment):
    """Set the comment to be stored on the active revision."""
    _update_frame(comment=comment)


def get_comment():
    """Return the comment set on the active revision ("" by default)."""
    return _current_frame().comment


def set_date_created(date_created):
    """Override the creation date to be stored on the active revision."""
    _update_frame(date_created=date_created)


def get_date_created():
    """Return the creation date of the active revision."""
    return _current_frame().date_created


def add_meta(model, **values):
    """Queue a ``model`` row to be created with ``values`` plus ``revision=`` when the revision is saved."""
    _update_frame(meta=_current_frame().meta + ((model, values),))


def _follow_relations(obj):
    """Yield the objects referenced by ``obj``'s registered ``follow`` attributes.

    Model instances are yielded directly, and managers/querysets yield every row.
    Attributes that raise ObjectDoesNotExist, or that are None, are skipped.

    Raises RegistrationError if ``obj``'s model is not registered, or if a
    followed attribute is any other type.
    """
    version_options = _get_options(obj.__class__)
    for follow_name in version_options.follow:
        try:
            follow_obj = getattr(obj, follow_name)
        except ObjectDoesNotExist:
            continue
        if isinstance(follow_obj, models.Model):
            yield follow_obj
        elif isinstance(follow_obj, (models.Manager, QuerySet)):
            for follow_obj_instance in follow_obj.all():
                yield follow_obj_instance
        elif follow_obj is not None:
            raise RegistrationError("{name}.{follow_name} should be a Model or QuerySet".format(
                name=obj.__class__.__name__,
                follow_name=follow_name,
            ))


def _follow_relations_recursive(obj):
    """Return a set holding ``obj`` and every object transitively reachable via ``follow``.

    An explicit work list replaces recursion, so long follow chains cannot exceed
    Python's recursion limit. The ``relations`` set also terminates cycles.
    """
    relations = set()
    pending = [obj]
    while pending:
        current = pending.pop()
        if current in relations:
            continue
        relations.add(current)
        pending.extend(_follow_relations(current))
    return relations


def _add_to_revision(obj, using, model_db, explicit):
    """Record a Version of ``obj`` in the current frame's ``using`` bucket, then follow relations.

    ``model_db`` is the database ``obj`` was read from. ``explicit`` is True for
    objects passed directly to ``add_to_revision()``: such objects replace an
    existing entry and are subject to ``ignore_duplicates``. Followed objects
    (``explicit=False``) are skipped if already present, which also stops
    follow cycles.
    """
    from reversion.models import Version
    # Skip objects that have not been saved yet (no primary key).
    if obj.pk is None:
        return
    version_options = _get_options(obj.__class__)
    content_type = _get_content_type(obj.__class__, using)
    object_id = force_str(obj.pk)
    version_key = (content_type, object_id)
    # If the obj is already in the revision, stop now.
    db_versions = _current_frame().db_versions
    versions = db_versions[using]
    if version_key in versions and not explicit:
        return
    # Get the version data.
    version = Version(
        content_type=content_type,
        object_id=object_id,
        db=model_db,
        format=version_options.format,
        serialized_data=serializers.serialize(
            version_options.format,
            (obj,),
            fields=version_options.fields,
        ),
        object_repr=force_str(obj),
    )
    # If the version is a duplicate of the latest stored one, stop now.
    if version_options.ignore_duplicates and explicit:
        previous_version = Version.objects.using(using).get_for_object(obj, model_db=model_db).first()
        if previous_version and previous_version._local_field_dict == version._local_field_dict:
            return
    # Store the version in a copy of the frame state (frames are immutable).
    db_versions = _copy_db_versions(db_versions)
    db_versions[using][version_key] = version
    _update_frame(db_versions=db_versions)
    # Follow relations.
    for follow_obj in _follow_relations(obj):
        _add_to_revision(follow_obj, using, model_db, False)


def add_to_revision(obj, model_db=None):
    """Add ``obj`` and its followed relations to the active revision, for every tracked database.

    ``model_db`` is the database ``obj`` lives in. It defaults to the router's
    write database for the object. Raises RevisionManagementError if no
    revision is active.
    """
    model_db = model_db or router.db_for_write(obj.__class__, instance=obj)
    for db in _current_frame().db_versions.keys():
        _add_to_revision(obj, db, model_db, True)


def _save_revision(versions, user=None, comment="", meta=(), date_created=None, using=None):
    """Persist a Revision, its Versions and its meta rows to database ``using``.

    Versions whose object no longer exists in its source database
    (``version.db``) are dropped first. If none remain, nothing is saved.
    ``pre_revision_commit`` is sent before the Revision is saved, and
    ``post_revision_commit`` is sent after everything is saved.
    """
    from reversion.models import Revision
    # ``versions`` is iterated more than once below; materialise it so a
    # one-shot iterable is not silently exhausted by the first pass.
    versions = list(versions)
    # Only save versions that exist in the database.
    # Use _base_manager so we don't have problems when _default_manager is overridden.
    model_db_pks = defaultdict(lambda: defaultdict(set))
    for version in versions:
        model_db_pks[version._model][version.db].add(version.object_id)
    model_db_existing_pks = {
        model: {
            db: frozenset(map(
                force_str,
                model._base_manager.using(db).filter(pk__in=pks).values_list("pk", flat=True),
            ))
            for db, pks in db_pks.items()
        }
        for model, db_pks in model_db_pks.items()
    }
    versions = [
        version for version in versions
        if version.object_id in model_db_existing_pks[version._model][version.db]
    ]
    # Bail early if there are no objects to save.
    if not versions:
        return
    # Save a new revision.
    revision = Revision(
        date_created=date_created,
        user=user,
        comment=comment,
    )
    # Send the pre_revision_commit signal.
    pre_revision_commit.send(
        sender=create_revision,
        revision=revision,
        versions=versions,
    )
    # Save the revision.
    revision.save(using=using)
    # Save version models.
    for version in versions:
        version.revision = revision
        version.save(using=using)
    # Save the meta information.
    for meta_model, meta_fields in meta:
        meta_model._base_manager.db_manager(using=using).create(
            revision=revision,
            **meta_fields
        )
    # Send the post_revision_commit signal.
    post_revision_commit.send(
        sender=create_revision,
        revision=revision,
        versions=versions,
    )


@contextmanager
def _dummy_context():
    """No-op context manager, used instead of a transaction when ``atomic=False``."""
    yield


@contextmanager
def _create_revision_context(manage_manually, using, atomic):
    """Context manager that backs ``create_revision()``.

    Pushes a frame for ``using``. On normal exit, the collected versions are
    saved only if no enclosing frame also tracks ``using``, so nested blocks
    defer to the outermost one. When ``atomic`` is True, the save runs inside
    ``transaction.atomic(using=using)``. The frame is always popped. If the
    body raises, the save is skipped.
    """
    context = transaction.atomic(using=using) if atomic else _dummy_context()
    with context:
        _push_frame(manage_manually, using)
        try:
            yield
            # Only save for a db if that's the last stack frame for that db.
            if not any(using in frame.db_versions for frame in _local.stack[:-1]):
                current_frame = _current_frame()
                _save_revision(
                    versions=current_frame.db_versions[using].values(),
                    user=current_frame.user,
                    comment=current_frame.comment,
                    meta=current_frame.meta,
                    date_created=current_frame.date_created,
                    using=using,
                )
        finally:
            _pop_frame()


def create_revision(manage_manually=False, using=None, atomic=True):
    """Return an object usable as a ``with`` block or as a decorator that opens a revision.

    ``using`` defaults to the router's write database for Revision. If
    ``manage_manually`` is True, the signal receivers in this module do not
    automatically add saved instances to the revision.
    """
    from reversion.models import Revision
    using = using or router.db_for_write(Revision)
    return _ContextWrapper(_create_revision_context, (manage_manually, using, atomic))


class _ContextWrapper(object):
    """Adapter that lets ``func(*args)`` act as a context manager or as a decorator.

    The context built in ``__init__`` serves ``with`` usage. It is a
    ``@contextmanager`` object, so one wrapper can be entered with ``with``
    only once. Decorator usage builds a fresh context on every call.
    """

    def __init__(self, func, args):
        self._func = func
        self._args = args
        self._context = func(*args)

    def __enter__(self):
        return self._context.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        return self._context.__exit__(exc_type, exc_value, traceback)

    def __call__(self, func):
        @wraps(func)
        def do_revision_context(*args, **kwargs):
            with self._func(*self._args):
                return func(*args, **kwargs)
        return do_revision_context


def _post_save_receiver(sender, instance, using, **kwargs):
    """post_save handler that adds registered instances to the active, non-manual revision."""
    if is_registered(sender) and is_active() and not is_manage_manually():
        add_to_revision(instance, model_db=using)


def _m2m_changed_receiver(instance, using, action, model, reverse, **kwargs):
    """m2m_changed handler that adds the instance after a forward-side ("post_*") m2m change."""
    if action.startswith("post_") and not reverse:
        if is_registered(instance) and is_active() and not is_manage_manually():
            add_to_revision(instance, model_db=using)


def _get_registration_key(model):
    """Return the (app_label, model_name) key for a model class or instance."""
    return (model._meta.app_label, model._meta.model_name)


# Maps (app_label, model_name) -> _VersionOptions for every registered model.
_registered_models = {}


def is_registered(model):
    """Return True if ``model`` (a class or instance) is registered."""
    return _get_registration_key(model) in _registered_models


def get_registered_models():
    """Return a generator of all registered model classes, resolved through the app registry."""
    return (apps.get_model(*key) for key in _registered_models.keys())
def _get_senders_and_signals(model):
    """Yield (sender, signal, receiver) triples that must be (dis)connected for ``model``.

    The first triple wires ``post_save`` to the model itself. One ``m2m_changed``
    triple follows for each local many-to-many field of the concrete model,
    with the through model as sender. A lazy string reference without an app
    label is qualified with this model's app label.
    """
    yield model, post_save, _post_save_receiver
    concrete_opts = model._meta.concrete_model._meta
    for m2m_field in concrete_opts.local_many_to_many:
        through = m2m_field.remote_field.through
        if isinstance(through, str) and "." not in through:
            through = "{app_label}.{m2m_model}".format(
                app_label=concrete_opts.app_label,
                m2m_model=through,
            )
        yield through, m2m_changed, _m2m_changed_receiver


def register(model=None, fields=None, exclude=(), follow=(), format="json",
             for_concrete_model=True, ignore_duplicates=False):
    """Register ``model`` with django-reversion, or return a class decorator when ``model`` is None.

    ``fields`` defaults to the names of all local fields and local many-to-many
    fields of the concrete model. Any names listed in ``exclude`` are then
    removed. ``follow`` lists attributes whose related objects are added to
    revisions together with this model. The module's post_save/m2m_changed
    receivers are connected.

    Raises RegistrationError if the model is already registered.
    """
    def decorator(model):
        if is_registered(model):
            raise RegistrationError("{model} has already been registered with django-reversion".format(
                model=model,
            ))
        concrete_opts = model._meta.concrete_model._meta
        if fields is None:
            candidate_names = [
                field.name
                for field in concrete_opts.local_fields + concrete_opts.local_many_to_many
            ]
        else:
            candidate_names = fields
        options = _VersionOptions(
            fields=tuple(name for name in candidate_names if name not in exclude),
            follow=tuple(follow),
            format=format,
            for_concrete_model=for_concrete_model,
            ignore_duplicates=ignore_duplicates,
        )
        _registered_models[_get_registration_key(model)] = options
        for sender, signal, receiver in _get_senders_and_signals(model):
            signal.connect(receiver, sender=sender)
        return model

    return decorator if model is None else decorator(model)


def _assert_registered(model):
    """Raise RegistrationError unless ``model`` is registered."""
    if is_registered(model):
        return
    raise RegistrationError("{model} has not been registered with django-reversion".format(
        model=model,
    ))


def _get_options(model):
    """Return the _VersionOptions stored for ``model``; raise RegistrationError if it is unregistered."""
    _assert_registered(model)
    registration_key = _get_registration_key(model)
    return _registered_models[registration_key]


def unregister(model):
    """Remove ``model``'s registration and disconnect its signal receivers."""
    _assert_registered(model)
    _registered_models.pop(_get_registration_key(model))
    for sender, signal, receiver in _get_senders_and_signals(model):
        signal.disconnect(receiver, sender=sender)


def _get_content_type(model, using):
    """Return the ContentType for ``model`` from database ``using``.

    Whether the concrete model or the proxy is used follows the model's
    ``for_concrete_model`` registration option.
    """
    from django.contrib.contenttypes.models import ContentType
    concrete_only = _get_options(model).for_concrete_model
    content_types = ContentType.objects.db_manager(using)
    return content_types.get_for_model(model, for_concrete_model=concrete_only)