1#    Licensed under the Apache License, Version 2.0 (the "License"); you may
2#    not use this file except in compliance with the License. You may obtain
3#    a copy of the License at
4#
5#         http://www.apache.org/licenses/LICENSE-2.0
6#
7#    Unless required by applicable law or agreed to in writing, software
8#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10#    License for the specific language governing permissions and limitations
11#    under the License.
12"""Fixtures for writing tests for code using oslo.versionedobjects
13
14.. note::
15
16   This module has several extra dependencies not needed at runtime
17   for production code, and therefore not installed by default. To
18   ensure those dependencies are present for your tests, add
19   ``oslo.versionedobjects[fixtures]`` to your list of test dependencies.
20
21"""
22
23from collections import namedtuple
24from collections import OrderedDict
25import copy
26import datetime
27import inspect
28import logging
29from unittest import mock
30
31import fixtures
32from oslo_utils.secretutils import md5
33from oslo_utils import versionutils as vutils
34
35from oslo_versionedobjects import base
36from oslo_versionedobjects import fields
37
38
# Module-level logger, named after this module; used below to emit debug
# output while object compatibility is being exercised.
LOG = logging.getLogger(__name__)
40
41
def compare_obj(test, obj, db_obj, subs=None, allow_missing=None,
                comparators=None):
    """Compare a VersionedObject and a dict-like database object.

    This iterates over the fields of the object and compares each one
    with the corresponding entry of ``db_obj``. Datetime values on either
    side have their tzinfo stripped before being compared, so a TZ-aware
    value and a naive value with the same wall-clock time are equal.

    :param test: The TestCase doing the comparison; its ``assertEqual``
                 is used for every field without a custom comparator
    :param obj: The VersionedObject to examine
    :param db_obj: The dict-like database object to use as reference
    :param subs: A dict of objkey=dbkey field substitutions
    :param allow_missing: A list of fields that may not be in db_obj
    :param comparators: Map of comparator functions to use for certain
                        fields; each is called as
                        ``comparator(db_val, obj_val)``
    :raises AssertionError: if a field is set on only one of ``obj`` and
                            ``db_obj`` (and is not covered by
                            ``allow_missing``)
    """

    subs = subs or {}
    allow_missing = allow_missing or []
    comparators = comparators or {}

    for key in obj.fields:
        db_key = subs.get(key, key)
        in_db = db_key in db_obj

        # If this is an allow_missing key and it's missing in either obj or
        # db_obj, just skip it
        if key in allow_missing and (key not in obj or not in_db):
            continue

        is_set = obj.obj_attr_is_set(key)
        if not in_db:
            # Unset on both sides counts as equal; set only on the object
            # does not.
            if not is_set:
                continue
            raise AssertionError(("%s (db_key: %s) is set on the object, but "
                                  "not on the db_obj, so the objects are not "
                                  "equal")
                                 % (key, db_key))
        if not is_set:
            raise AssertionError(("%s (db_key: %s) is set on the db_obj, but "
                                  "not on the object, so the objects are not "
                                  "equal")
                                 % (key, db_key))

        # Both sides are known to hold a value from here on.
        obj_val = getattr(obj, key)
        db_val = db_obj[db_key]
        if isinstance(obj_val, datetime.datetime):
            obj_val = obj_val.replace(tzinfo=None)

        # Normalise the database value from itself. Deriving it from
        # obj_val would make every datetime field compare equal to itself
        # and hide real mismatches.
        if isinstance(db_val, datetime.datetime):
            db_val = db_val.replace(tzinfo=None)

        comparator = comparators.get(key)
        if key in comparators:
            comparator(db_val, obj_val)
        else:
            test.assertEqual(db_val, obj_val)
102
103
class FakeIndirectionAPI(base.VersionedObjectIndirectionAPI):
    """Indirection API that runs calls locally through a serializer.

    Objects and arguments are serialized and deserialized again before
    the target method is invoked, and the method itself runs with
    ``VersionedObject.indirection_api`` patched to ``None``.

    :param serializer: Serializer used for the round-trips; a
                       ``base.VersionedObjectSerializer`` is created when
                       none (or a falsy value) is given
    """

    def __init__(self, serializer=None):
        super().__init__()
        self._ser = serializer or base.VersionedObjectSerializer()

    def _roundtrip(self, context, entity):
        """Serialize ``entity`` and immediately deserialize the result."""
        return self._ser.deserialize_entity(
            context, self._ser.serialize_entity(context, entity))

    @staticmethod
    def _no_indirection():
        """Return a patcher that sets the indirection API to None."""
        target = 'oslo_versionedobjects.base.VersionedObject.indirection_api'
        return mock.patch(target, new=None)

    def _get_changes(self, orig_obj, new_obj):
        """Return primitives for fields of new_obj that differ from orig_obj.

        Fields that are not set on ``new_obj`` are ignored; a field counts
        as changed when it was unset on ``orig_obj`` or its value differs.
        """
        updates = {}
        for name, field in new_obj.fields.items():
            if not new_obj.obj_attr_is_set(name):
                continue
            changed = (not orig_obj.obj_attr_is_set(name) or
                       getattr(orig_obj, name) != getattr(new_obj, name))
            if changed:
                updates[name] = field.to_primitive(new_obj, name,
                                                   getattr(new_obj, name))
        return updates

    def _canonicalize_args(self, context, args, kwargs):
        """Round-trip every positional and keyword argument."""
        canonical_args = tuple(
            [self._roundtrip(context, value) for value in args])
        canonical_kwargs = {
            argname: self._roundtrip(context, value)
            for argname, value in kwargs.items()}
        return canonical_args, canonical_kwargs

    def _call_class_method(self, context, objname, objver, objmethod,
                           args, kwargs):
        """Call ``objmethod`` on the named class and convert the result.

        A VersionedObject result is converted to a primitive at ``objver``
        and rebuilt from it; any other result is returned unchanged.
        """
        cls = base.VersionedObject.obj_class_from_name(objname, objver)
        with self._no_indirection():
            result = getattr(cls, objmethod)(context, *args, **kwargs)
        if not isinstance(result, base.VersionedObject):
            return result
        primitive = result.obj_to_primitive(target_version=objver)
        return base.VersionedObject.obj_from_primitive(primitive,
                                                       context=context)

    def object_action(self, context, objinst, objmethod, args, kwargs):
        """Run an instance method; return ``(updates, result)``.

        ``updates`` holds the changed fields as primitives plus an
        ``obj_what_changed`` entry taken from the object after the call.
        """
        objinst = self._roundtrip(context, objinst)
        objmethod = str(objmethod)
        args, kwargs = self._canonicalize_args(context, args, kwargs)
        before = objinst.obj_clone()
        with self._no_indirection():
            result = getattr(objinst, objmethod)(*args, **kwargs)
        updates = self._get_changes(before, objinst)
        updates['obj_what_changed'] = objinst.obj_what_changed()
        return updates, result

    def object_class_action(self, context, objname, objmethod, objver,
                            args, kwargs):
        """Run a class-level method of ``objname`` at version ``objver``."""
        objname = str(objname)
        objmethod = str(objmethod)
        objver = str(objver)
        args, kwargs = self._canonicalize_args(context, args, kwargs)
        return self._call_class_method(context, objname, objver, objmethod,
                                       args, kwargs)

    def object_class_action_versions(self, context, objname, objmethod,
                                     object_versions, args, kwargs):
        """Run a class-level method, taking the version from a mapping.

        The version used is ``object_versions[objname]`` after both keys
        and values of the mapping have been converted to ``str``.
        """
        objname = str(objname)
        objmethod = str(objmethod)
        object_versions = {str(name): str(version)
                           for name, version in object_versions.items()}
        args, kwargs = self._canonicalize_args(context, args, kwargs)
        objver = object_versions[objname]
        return self._call_class_method(context, objname, objver, objmethod,
                                       args, kwargs)

    def object_backport(self, context, objinst, target_version):
        """Backporting is not implemented by this fake; always raises."""
        raise Exception('not supported')
178
179
class IndirectionFixture(fixtures.Fixture):
    """Fixture that installs an indirection API on VersionedObject.

    :param indirection_api: The API object to install; a
                            ``FakeIndirectionAPI`` is created when none
                            (or a falsy value) is given
    """

    def __init__(self, indirection_api=None):
        if not indirection_api:
            indirection_api = FakeIndirectionAPI()
        self.indirection_api = indirection_api

    def setUp(self):
        super().setUp()
        target = 'oslo_versionedobjects.base.VersionedObject.indirection_api'
        self.useFixture(fixtures.MonkeyPatch(target, self.indirection_api))
189
190
class ObjectHashMismatch(Exception):
    """Error describing a difference between two sets of object hashes.

    :param expected: Mapping whose keys name the affected objects on the
                     expected side
    :param actual: Mapping whose keys name the affected objects on the
                   actual side
    """

    def __init__(self, expected, actual):
        self.expected = expected
        self.actual = actual

    def __str__(self):
        # dict views cannot be concatenated with "+", so build the union
        # of the keys as sets. Sorting keeps the message deterministic.
        names = set(self.expected) | set(self.actual)
        return 'Hashes have changed for %s' % ','.join(sorted(names))
199
200
# Tuple type with the same name and field layout as the result of the old
# inspect.getargspec(), used by get_method_spec() below.
CompatArgSpec = namedtuple(
    'ArgSpec', ['args', 'varargs', 'keywords', 'defaults'])


def get_method_spec(method):
    """Return a stable, getargspec()-compatible spec for ``method``.

    Many object hashes were recorded from inspect.getargspec() output,
    which cannot represent keyword-only arguments or annotations. When
    ``method`` uses neither, the result is a four-field ``ArgSpec``
    tuple shaped like the getargspec() one; otherwise the
    inspect.getfullargspec() result is returned as-is.
    """
    spec = inspect.getfullargspec(method)
    uses_newer_features = (spec.kwonlyargs or spec.kwonlydefaults or
                           spec.annotations)
    if uses_newer_features:
        # Not representable in the old format, so hand back the full spec
        return spec
    return CompatArgSpec(args=spec.args, varargs=spec.varargs,
                         keywords=spec.varkw, defaults=spec.defaults)
224
225
class ObjectVersionChecker(object):
    """Checks that help detect object changes needing a version bump.

    :param obj_classes: Mapping of object name to a sequence of object
                        classes. When omitted, the mapping returned by
                        ``base.VersionedObjectRegistry.obj_classes()`` at
                        construction time is used.
    """

    def __init__(self, obj_classes=None):
        # Resolve the default here rather than in the signature: a default
        # argument is evaluated once, when the class body is executed, so
        # every checker would keep using whatever mapping the registry
        # returned at import time.
        if obj_classes is None:
            obj_classes = base.VersionedObjectRegistry.obj_classes()
        self.obj_classes = obj_classes

    def _find_remotable_method(self, cls, thing, parent_was_remotable=False):
        """Follow a chain of remotable things down to the original function.

        :returns: the first non-remotable callable found underneath at
                  least one remotable layer, or None if ``thing`` never
                  passes through a remotable layer
        """
        if isinstance(thing, classmethod):
            return self._find_remotable_method(cls, thing.__get__(None, cls))
        elif (inspect.ismethod(thing) or
              inspect.isfunction(thing)) and hasattr(thing, 'remotable'):
            return self._find_remotable_method(cls, thing.original_fn,
                                               parent_was_remotable=True)
        elif parent_was_remotable:
            # We must be the first non-remotable thing underneath a stack of
            # remotable things (i.e. the actual implementation method)
            return thing
        else:
            # This means the top-level thing never hit a remotable layer
            return None

    def _get_fingerprint(self, obj_name, extra_data_func=None):
        """Return ``'<VERSION>-<md5 hex digest>'`` for the named object.

        The digest covers the repr() of the sorted fields, the sorted
        remotable method specs, ``child_versions`` when the class has
        that attribute, and whatever ``extra_data_func`` returns.
        """
        obj_class = self.obj_classes[obj_name][0]
        obj_fields = sorted(obj_class.fields.items())
        methods = []
        for name in dir(obj_class):
            thing = getattr(obj_class, name)
            if (inspect.ismethod(thing) or inspect.isfunction(thing) or
                    isinstance(thing, classmethod)):
                method = self._find_remotable_method(obj_class, thing)
                if method:
                    methods.append((name, get_method_spec(method)))
        methods.sort()
        # NOTE(danms): Things that need a version bump are any fields
        # and their types, or the signatures of any remotable methods.
        # Of course, these are just the mechanical changes we can detect,
        # but many other things may require a version bump (method behavior
        # and return value changes, for example).
        if hasattr(obj_class, 'child_versions'):
            relevant_data = (obj_fields, methods,
                             OrderedDict(
                                 sorted(obj_class.child_versions.items())))
        else:
            relevant_data = (obj_fields, methods)

        if extra_data_func:
            relevant_data += extra_data_func(obj_class)

        digest = md5(repr(relevant_data).encode(),
                     usedforsecurity=False).hexdigest()
        return '%s-%s' % (obj_class.VERSION, digest)

    def get_hashes(self, extra_data_func=None):
        """Return a dict of computed object hashes.

        :param extra_data_func: a function that is given the object class
                                which gathers more relevant data about the
                                class that is needed in versioning. Returns
                                a tuple containing the extra data bits.
        """

        fingerprints = {}
        for obj_name in sorted(self.obj_classes):
            fingerprints[obj_name] = self._get_fingerprint(
                obj_name, extra_data_func=extra_data_func)
        return fingerprints

    def test_hashes(self, expected_hashes, extra_data_func=None):
        """Compare stored hashes with freshly computed ones.

        :param expected_hashes: dict of object name to stored fingerprint
        :param extra_data_func: passed through to :meth:`get_hashes`
        :returns: a tuple ``(expected, actual)`` of dicts keyed by the
                  names whose fingerprint differs or exists on one side
                  only; the missing side is reported as None
        """
        fingerprints = self.get_hashes(extra_data_func=extra_data_func)

        stored = set(expected_hashes.items())
        computed = set(fingerprints.items())
        changed_names = {name for name, _ in stored ^ computed}
        expected = {name: expected_hashes.get(name) for name in changed_names}
        actual = {name: fingerprints.get(name) for name in changed_names}

        return expected, actual

    def _get_dependencies(self, tree, obj_class):
        """Add ``obj_class`` and its sub-objects to ``tree`` in place.

        Only fields whose type is a ``fields.Object`` contribute; each is
        recorded as ``tree[obj_name][sub_obj_name] = sub_obj VERSION``.
        """
        obj_name = obj_class.obj_name()
        if obj_name in tree:
            return

        for field in obj_class.fields.values():
            if not isinstance(field._type, fields.Object):
                continue
            sub_obj_name = field._type._obj_name
            sub_obj_class = self.obj_classes[sub_obj_name][0]
            # Record this object before descending, so that an object
            # graph containing a cycle (or a self-reference) hits the
            # "already in tree" guard instead of recursing forever.
            deps = tree.setdefault(obj_name, {})
            self._get_dependencies(tree, sub_obj_class)
            deps[sub_obj_name] = sub_obj_class.VERSION

    def get_dependency_tree(self):
        """Return ``{obj_name: {sub_obj_name: version}}`` for all classes.

        Objects without any object-typed field have no entry.
        """
        tree = {}
        for obj_name in self.obj_classes:
            self._get_dependencies(tree, self.obj_classes[obj_name][0])
        return tree

    def test_relationships(self, expected_tree):
        """Compare a stored dependency tree with the computed one.

        :param expected_tree: dict shaped like the result of
                              :meth:`get_dependency_tree`
        :returns: a tuple ``(expected, actual)`` of dicts keyed by the
                  names whose dependencies differ or exist on one side
                  only; the missing side is reported as None
        """
        actual_tree = self.get_dependency_tree()

        # Compare the dependency mappings themselves. Comparing their
        # str() form would report a change whenever two equal dicts
        # merely list their keys in a different order.
        absent = object()
        expected = {}
        actual = {}
        for name in set(expected_tree) | set(actual_tree):
            if (expected_tree.get(name, absent) !=
                    actual_tree.get(name, absent)):
                expected[name] = expected_tree.get(name)
                actual[name] = actual_tree.get(name)

        return expected, actual

    def _test_object_compatibility(self, obj_class, manifest=None,
                                   init_args=None, init_kwargs=None):
        """Call obj_to_primitive() for every minor version of obj_class.

        Versions ``<major>.0`` up to the class's current minor version
        are tried, each on a freshly constructed instance.
        """
        init_args = init_args or []
        init_kwargs = init_kwargs or {}
        version = vutils.convert_version_to_tuple(obj_class.VERSION)
        kwargs = {'version_manifest': manifest} if manifest else {}
        for minor in range(version[1] + 1):
            test_version = '%d.%d' % (version[0], minor)
            # Run the test with OS_DEBUG=True to see this.
            LOG.debug('testing obj: %s version: %s',
                      obj_class.obj_name(), test_version)
            kwargs['target_version'] = test_version
            obj_class(*init_args, **init_kwargs).obj_to_primitive(**kwargs)

    def test_compatibility_routines(self, use_manifest=False, init_args=None,
                                    init_kwargs=None):
        """Test obj_make_compatible() on all object classes.

        :param use_manifest: a boolean that determines if the version
                             manifest should be passed to obj_make_compatible
        :param init_args: a dictionary of the format {obj_class: [arg1, arg2]}
                          that will be used to pass arguments to init on the
                          given obj_class. If no args are needed, the
                          obj_class does not need to be added to the dict
        :param init_kwargs: a dictionary of the format
                            {obj_class: {'kwarg1': val1}} that will be used to
                            pass kwargs to init on the given obj_class. If no
                            kwargs are needed, the obj_class does not need to
                            be added to the dict
        """
        # Iterate all object classes and verify that we can run
        # obj_make_compatible with every older version than current.
        # This doesn't actually test the data conversions, but it at least
        # makes sure the method doesn't blow up on something basic like
        # expecting the wrong version format.
        init_args = init_args or {}
        init_kwargs = init_kwargs or {}
        for obj_name in self.obj_classes:
            obj_classes = self.obj_classes[obj_name]
            if use_manifest:
                manifest = base.obj_tree_get_versions(obj_name)
            else:
                manifest = None

            for obj_class in obj_classes:
                args_for_init = init_args.get(obj_class, [])
                kwargs_for_init = init_kwargs.get(obj_class, {})
                self._test_object_compatibility(obj_class, manifest=manifest,
                                                init_args=args_for_init,
                                                init_kwargs=kwargs_for_init)

    def _test_relationships_in_order(self, obj_class):
        """Check the ordering of every obj_relationships entry.

        :raises AssertionError: if, within one field's list, the object's
                                own version is not strictly increasing or
                                the child version decreases
        """
        for field, versions in obj_class.obj_relationships.items():
            last_my_version = (0, 0)
            last_child_version = (0, 0)
            for my_version, child_version in versions:
                _my_version = vutils.convert_version_to_tuple(my_version)
                _ch_version = vutils.convert_version_to_tuple(child_version)
                if not (last_my_version < _my_version and
                        last_child_version <= _ch_version):
                    raise AssertionError(('Object %s relationship %s->%s for '
                                          'field %s is out of order') % (
                                              obj_class.obj_name(),
                                              my_version, child_version,
                                              field))
                last_my_version = _my_version
                last_child_version = _ch_version

    def test_relationships_in_order(self):
        """Check obj_relationships ordering on all object classes."""
        # Iterate all object classes and verify that the version pairs
        # listed in obj_relationships appear in ascending order.
        for obj_name in self.obj_classes:
            obj_classes = self.obj_classes[obj_name]
            for obj_class in obj_classes:
                self._test_relationships_in_order(obj_class)
419
420
class VersionedObjectRegistryFixture(fixtures.Fixture):
    """Fixture providing a temporary VersionedObjectRegistry state.

    On setUp the registered object classes are backed up, and a cleanup
    is scheduled that puts the backup in place again. Classes registered
    in between are therefore only visible for the lifetime of the
    fixture, which suits test objects that must be found by registry
    lookups but should not stay registered permanently.
    """

    def setUp(self):
        super().setUp()
        registry = base.VersionedObjectRegistry._registry
        # Deep copy, so later registrations cannot leak into the backup
        self._base_test_obj_backup = copy.deepcopy(registry._obj_classes)
        self.addCleanup(self._restore_obj_registry)

    @staticmethod
    def register(cls_name):
        """Register ``cls_name`` with the VersionedObjectRegistry."""
        base.VersionedObjectRegistry.register(cls_name)

    def _restore_obj_registry(self):
        """Replace the registry's classes with the backup from setUp."""
        registry = base.VersionedObjectRegistry._registry
        registry._obj_classes = self._base_test_obj_backup
443
444
class StableObjectJsonFixture(fixtures.Fixture):
    """Fixture that makes sure we get stable JSON object representations.

    Objects contain things like set(), which cannot be converted to JSON
    directly, so in some situations the primitive form is not fully
    deterministic. That is harmless at runtime but gets in the way of
    unit tests asserting on the low-level representation.

    While active, this fixture replaces obj_to_primitive() with a wrapper
    that sorts the list of changed fields (which came from a set) before
    handing the primitive back to the caller.
    """

    def __init__(self):
        # Grab the unpatched implementation up front for the wrapper
        self._original_otp = base.VersionedObject.obj_to_primitive

    def setUp(self):
        super().setUp()

        def _stable_obj_to_primitive(obj, *args, **kwargs):
            primitive = self._original_otp(obj, *args, **kwargs)
            key = obj._obj_primitive_key('changes')
            if key in primitive:
                primitive[key].sort()
            return primitive

        target = 'oslo_versionedobjects.base.VersionedObject.obj_to_primitive'
        self.useFixture(fixtures.MonkeyPatch(target,
                                             _stable_obj_to_primitive))
473