1# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only under
6# the conditions described in the aforementioned license. The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8#
9# Thanks for using Enthought open source!
10
11""" Trait assert mixin class to simplify test implementation for Trait
12Classes.
13
14"""
15
16import contextlib
17import threading
18import sys
19import warnings
20
21# Support for 'from traits.testing.unittest_tools import unittest',
22# which was used to make unittest2 available under the name unittest
23# on Python 2.6. We keep the import for now, for backwards compatibility.
24import unittest  # noqa: F401
25
26from traits.api import (
27    Any,
28    Event,
29    HasStrictTraits,
30    Int,
31    List,
32    Str,
33    Property,
34)
35from traits.util.async_trait_wait import wait_for_condition
36
37
class _AssertTraitChangesContext(object):
    """ A context manager used to implement the trait change assert methods.

    While active, the context listens to ``xname`` on ``obj`` and records
    every change notification it receives. When the with-block completes
    without raising, the recorded notifications are checked against
    ``count``.

    Parameters
    ----------
    obj : HasTraits
        The HasTraits class instance whose class trait will change.

    xname : str
        The extended trait name of trait changes to listen to.

    count : int or None
        The expected number of times the event should be fired. When None
        there is no check for the exact number of times the change event
        was fired, but at least one change event is required.

    test_case : TestCase
        A unittest TestCase whose ``failureException`` is raised when the
        check fails.

    Attributes
    ----------
    obj : HasTraits
        The HasTraits class instance whose class trait will change.

    xname : str
        The extended trait name of trait changes to listen to.

    count : int or None
        The expected number of times the event should be fired, or None
        to only require that at least one event was fired.

    event : tuple or None
        The arguments (<object>, <name>, <old>, <new>) of the most recent
        change notification, or None if no notification was received.

    events : list of tuples
        A list with tuple elements containing the arguments of an
        `on_trait_change` event signature (<object>, <name>, <old>, <new>).

    failureException : type
        The exception type raised when the check fails; taken from
        ``test_case.failureException``.

    Raises
    ------
    failureException :
        On leaving a with-block that did not raise, when the number of
        recorded changes differs from ``count``, or when ``count`` is None
        and no change was recorded.

    Notes
    -----
    Checking if the provided xname corresponds to valid traits in the class
    is not implemented yet.

    """

    def __init__(self, obj, xname, count, test_case):
        self.obj = obj
        self.xname = xname
        self.count = count
        self.event = None
        self.events = []
        self.failureException = test_case.failureException

    def _listener(self, obj, name, old, new):
        """ Record a single trait change notification.

        The four-argument signature makes ``on_trait_change`` deliver the
        object, trait name, old value and new value.
        """
        self.event = (obj, name, old, new)
        self.events.append(self.event)

    def __enter__(self):
        """ Bind the trait listener and return this context.
        """
        self.obj.on_trait_change(self._listener, self.xname)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """ Remove the trait listener, then check the recorded events.

        The listener is always removed, even when the with-block raised.
        Otherwise it would stay hooked to ``obj`` and keep recording events
        after the context has ended. When the block raised, no check is
        performed and the original exception propagates unchanged.
        """
        self.obj.on_trait_change(self._listener, self.xname, remove=True)

        if exc_type is not None:
            # Don't mask the original exception with a count failure.
            return False

        if self.count is not None and len(self.events) != self.count:
            msg = "Change event for {0} was fired {1} times instead of {2}"
            items = self.xname, len(self.events), self.count
            raise self.failureException(msg.format(*items))
        elif self.count is None and not self.events:
            msg = "A change event was not fired for: {0}".format(self.xname)
            raise self.failureException(msg)
        return False
123
124
@contextlib.contextmanager
def reverse_assertion(context, msg):
    """ Invert an assertion context manager.

    The wrapped ``context`` is entered before the with-block runs and
    exited afterwards with ``(None, None, None)``. If that exit raises
    ``context.failureException``, the wrapped assertion failed. For the
    reversed assertion this is the success case, so the failure is
    swallowed. If the exit does not raise, the reversed assertion fails
    with ``context.failureException(msg)``.

    Parameters
    ----------
    context : context manager
        An object with ``__enter__``, ``__exit__`` and a
        ``failureException`` attribute (such as the trait change assertion
        context used by ``UnittestTools``).
    msg : str
        The failure message used when the wrapped assertion passes.

    Yields
    ------
    context
        The wrapped context, so callers can inspect what it recorded.
    """
    context.__enter__()
    try:
        yield context
    finally:
        # Always give the wrapped context the chance to clean up (e.g.
        # detach its listener), even when the with-block raised.
        try:
            context.__exit__(None, None, None)
        except context.failureException:
            # Catch the context's own failure type rather than a hard-coded
            # AssertionError, so custom TestCase.failureException types work.
            change_detected = False
        else:
            change_detected = True
    # Only reached when the with-block completed normally. If it raised,
    # that exception propagates unmasked instead of being replaced here.
    if change_detected:
        raise context.failureException(msg)
137
138
class _TraitsChangeCollector(HasStrictTraits):
    """ Thread-safe recorder of the new values delivered by trait changes.

    Between ``start_collecting`` and ``stop_collecting``, each change
    notification for ``trait_name`` on ``obj`` appends its new value to
    ``events`` (while holding a lock) and then fires
    ``event_count_updated``. Other code can react to that event and read
    the current total through ``event_count``.
    """

    # The object we're listening to.
    obj = Any

    # The (possibly extended) trait name(s).
    trait_name = Str

    # Read-only event count.
    event_count = Property(Int)

    # Event fired after each newly recorded change.
    event_count_updated = Event

    # New values received so far, in arrival order.
    events = List(Any)

    # Lock guarding ``events`` so it can be accessed from several threads.
    _lock = Any()

    def __init__(self, **traits):
        # Backwards compatibility: the old ``trait`` keyword is still
        # accepted (with a deprecation warning) and forwarded as
        # ``trait_name``.
        if "trait" in traits:
            warnings.warn(
                "The `trait` keyword is deprecated." " please use `trait_name`",
                DeprecationWarning,
                stacklevel=2,
            )
            traits["trait_name"] = traits.pop("trait")
        super().__init__(**traits)

        # Create the lock and the events list right here rather than via a
        # lazy default. That guarantees the lock is made exactly once and on
        # the constructing thread.
        self._lock = threading.Lock()
        self.events = []

    def start_collecting(self):
        """ Begin recording changes of ``trait_name`` on ``obj``. """
        self.obj.on_trait_change(self._event_handler, self.trait_name)

    def stop_collecting(self):
        """ Stop recording changes of ``trait_name`` on ``obj``. """
        handler, name = self._event_handler, self.trait_name
        self.obj.on_trait_change(handler, name, remove=True)

    def _event_handler(self, new):
        """ Store the new value, then announce that the count changed. """
        with self._lock:
            self.events.append(new)
        # Fired only after the lock is released. Listeners may read
        # event_count, whose getter takes the same (non-reentrant) lock.
        self.event_count_updated = True

    def _get_event_count(self):
        """ Traits property getter: number of recorded events.

        The length is read while holding the lock, for thread safety.
        """
        with self._lock:
            total = len(self.events)
        return total
199
200
class UnittestTools(object):
    """ Mixin class to augment the unittest.TestCase class with useful trait
    related assert methods.

    """

    def assertTraitChanges(
        self, obj, trait, count=None, callableObj=None, *args, **kwargs
    ):
        """ Assert an object trait changes a given number of times.

        Assert that the class trait changes exactly `count` times during
        execution of the provided function.

        This method can also be used in a with statement to assert that
        a class trait has changed during the execution of the code inside
        the with statement (similar to the assertRaises method). Please note
        that in that case the context manager returns itself and the user
        can introspect the information of:

        - The last event fired by accessing the ``event`` attribute of the
          returned object.

        - All the fired events by accessing the ``events`` attribute of
          the return object.

        Note that in the case of chained properties (trait 'foo' depends on
        'bar', which in turn depends on 'baz'), the order in which the
        corresponding trait events appear in the ``events`` attribute is
        not well-defined, and may depend on dictionary ordering.

        **Example**::

            class MyClass(HasTraits):
                number = Float(2.0)

            my_class = MyClass()

            with self.assertTraitChanges(my_class, 'number', count=1):
                my_class.number = 3.0

        Parameters
        ----------
        obj : HasTraits
            The HasTraits class instance whose class trait will change.

        trait : str
            The extended trait name of trait changes to listen to.

        count : int or None, optional
            The expected number of times the event should be fired. When None
            (default value) there is no check for the number of times the
            change event was fired.

        callableObj : callable, optional
            A callable object that will trigger the expected trait change.
            When None (default value) a trigger is expected to be called
            under the context manager returned by this method.

        *args :
            List of positional arguments for ``callableObj``

        **kwargs :
            Dict of keyword value pairs to be passed to the ``callableObj``


        Returns
        -------
        context : context manager or None
            If ``callableObj`` is None, an assertion context manager is
            returned, inside of which a trait-change trigger can be invoked.
            Otherwise, the context is used internally with ``callableObj`` as
            the trigger, in which case None is returned.

        Notes
        -----
        - Checking if the provided ``trait`` corresponds to valid traits in
          the class is not implemented yet.
        - Using the functional version of the assert method requires the
          ``count`` argument to be given even if it is None.

        """
        context = _AssertTraitChangesContext(obj, trait, count, self)
        if callableObj is None:
            return context
        with context:
            callableObj(*args, **kwargs)

    def assertTraitDoesNotChange(
        self, obj, trait, callableObj=None, *args, **kwargs
    ):
        """ Assert an object trait does not change.

        Assert that the class trait does not change during
        execution of the provided function.

        Parameters
        ----------
        obj : HasTraits
            The HasTraits class instance whose class trait will change.

        trait : str
            The extended trait name of trait changes to listen to.

        callableObj : callable, optional
            A callable object that should not trigger a change in the
            passed trait.  When None (default value) a trigger is expected
            to be called under the context manager returned by this method.

        *args :
            List of positional arguments for ``callableObj``

        **kwargs :
            Dict of keyword value pairs to be passed to the ``callableObj``


        Returns
        -------
        context : context manager or None
            If ``callableObj`` is None, an assertion context manager is
            returned, inside of which a trait-change trigger can be invoked.
            Otherwise, the context is used internally with ``callableObj`` as
            the trigger, in which case None is returned.

        """
        msg = "A change event was fired for: {0}".format(trait)
        context = _AssertTraitChangesContext(obj, trait, None, self)
        if callableObj is None:
            return reverse_assertion(context, msg)
        with reverse_assertion(context, msg):
            callableObj(*args, **kwargs)

    @contextlib.contextmanager
    def assertMultiTraitChanges(
        self, objects, traits_modified, traits_not_modified
    ):
        """ Assert that traits on multiple objects do or do not change.

        This combines some of the functionality of `assertTraitChanges` and
        `assertTraitDoesNotChange`.

        Parameters
        ----------
        objects : list of HasTraits
            The HasTraits class instances whose traits will change.

        traits_modified : list of str
            The extended trait names of trait expected to change.

        traits_not_modified : list of str
            The extended trait names of traits not expected to change.

        Yields
        ------
        contexts : tuple
            The entered assertion contexts. For each object in turn, it holds
            one context per name in ``traits_modified``, followed by one per
            name in ``traits_not_modified``.

        """
        with contextlib.ExitStack() as exit_stack:
            cms = []
            for obj in objects:
                for trait in traits_modified:
                    cms.append(exit_stack.enter_context(
                        self.assertTraitChanges(obj, trait)))
                for trait in traits_not_modified:
                    cms.append(exit_stack.enter_context(
                        self.assertTraitDoesNotChange(obj, trait)))
            yield tuple(cms)

    @contextlib.contextmanager
    def assertTraitChangesAsync(self, obj, trait, count=1, timeout=5.0):
        """ Assert an object trait eventually changes.

        Context manager used to assert that the given trait changes at
        least `count` times within the given timeout, as a result of
        execution of the body of the corresponding with block.

        The trait changes are permitted to occur asynchronously.

        **Example usage**::

            with self.assertTraitChangesAsync(my_object, 'SomeEvent', count=4):
                <do stuff that should cause my_object.SomeEvent to be
                fired at least 4 times within the next 5 seconds>


        Parameters
        ----------
        obj : HasTraits
            The HasTraits class instance whose class trait will change.

        trait : str
            The extended trait name of trait changes to listen to.

        count : int, optional
            The expected number of times the event should be fired.

        timeout : float or None, optional
            The amount of time in seconds to wait for the specified number
            of changes. None can be used to indicate no timeout.

        Yields
        ------
        collector
            The object recording the new values of the observed changes in
            its ``events`` list.

        """
        collector = _TraitsChangeCollector(obj=obj, trait_name=trait)

        # Pass control to body of the with statement.
        collector.start_collecting()
        try:
            yield collector

            # Wait for the expected number of events to arrive.
            try:
                wait_for_condition(
                    condition=lambda obj: obj.event_count >= count,
                    obj=collector,
                    trait="event_count_updated",
                    timeout=timeout,
                )
            except RuntimeError:
                actual_event_count = collector.event_count
                msg = (
                    "Expected {0} event on {1} to be fired at least {2} "
                    "times, but the event was only fired {3} times "
                    "before timeout ({4} seconds)."
                ).format(trait, obj, count, actual_event_count, timeout)
                self.fail(msg)

        finally:
            # Detach the listener whether or not the body or the wait failed.
            collector.stop_collecting()

    def assertEventuallyTrue(self, obj, trait, condition, timeout=5.0):
        """ Assert that the given condition is eventually true.

        Parameters
        ----------
        obj : HasTraits
            The HasTraits class instance whose traits will change.

        trait : str
            The extended trait name of trait changes to listen to.

        condition : callable
            A function that will be called when the specified trait
            changes.  This should accept ``obj`` and should return a
            Boolean indicating whether the condition is satisfied or not.

        timeout : float or None, optional
            The amount of time in seconds to wait for the condition to
            become true.  None can be used to indicate no timeout.

        """
        try:
            wait_for_condition(
                condition=condition, obj=obj, trait=trait, timeout=timeout
            )
        except RuntimeError:
            # Helpful to know whether we timed out because the
            # condition never became true, or because the expected
            # event was never issued.
            condition_at_timeout = condition(obj)
            self.fail(
                "Timed out waiting for condition. "
                "At timeout, condition was {0}.".format(condition_at_timeout)
            )

    @contextlib.contextmanager
    def _catch_warnings(self):
        """
        Replacement for warnings.catch_warnings.

        This method wraps warnings.catch_warnings, takes care to
        reset the warning registry before entering the with context,
        and ensures that DeprecationWarnings are always emitted.

        The hack to reset the warning registry is no longer needed in
        Python 3.4 and later. See http://bugs.python.org/issue4180 for
        more background.

        .. deprecated:: 6.2
            Use :func:`warnings.catch_warnings` instead.

        """
        warnings.warn(
            (
                "The _catch_warnings method is deprecated. "
                "Use warnings.catch_warnings instead."
            ),
            DeprecationWarning,
        )

        # Ugly hack copied from the core Python code (see
        # Lib/test/test_support.py) to reset the warnings registry
        # for the module making use of this context manager.
        #
        # Note that this hack is unnecessary in Python 3.4 and later; see
        # http://bugs.python.org/issue4180 for the background.
        registry = sys._getframe(4).f_globals.get("__warningregistry__")
        if registry:
            registry.clear()

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", DeprecationWarning)
            yield w

    @contextlib.contextmanager
    def assertDeprecated(self):
        """
        Assert that the code inside the with block is deprecated.  Intended
        for testing uses of traits.util.deprecated.deprecated.

        The context manager yields the list of all warnings recorded while
        the block runs. Only warnings whose category is DeprecationWarning
        (or a subclass of it) count towards the assertion, so unrelated
        warnings cannot satisfy it.

        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", DeprecationWarning)
            yield w

        deprecation_count = sum(
            issubclass(record.category, DeprecationWarning) for record in w
        )
        self.assertGreater(
            deprecation_count,
            0,
            msg="Expected a DeprecationWarning, " "but none was issued",
        )

    @contextlib.contextmanager
    def assertNotDeprecated(self):
        """
        Assert that the code inside the with block is not deprecated.  Intended
        for testing uses of traits.util.deprecated.deprecated.

        The context manager yields the list of all warnings recorded while
        the block runs. Only warnings whose category is DeprecationWarning
        (or a subclass of it) make the assertion fail; other warnings are
        ignored.

        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", DeprecationWarning)
            yield w

        deprecation_count = sum(
            issubclass(record.category, DeprecationWarning) for record in w
        )
        self.assertEqual(
            deprecation_count,
            0,
            msg="Expected no DeprecationWarning, "
            "but at least one was issued",
        )
533