1# (C) Copyright 2007-2019 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only
6# under the conditions described in the aforementioned license.  The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8# Thanks for using Enthought open source!
9""" Tests for class load hooks. """
10
11
12from envisage.api import ClassLoadHook
13from traits.api import HasTraits
14from traits.testing.unittest_tools import unittest
15
16
# Dotted name of the package this test module lives in.
# NOTE(review): not referenced anywhere else in this module; presumably kept
# for compatibility with other test modules -- confirm before removing.
PKG = 'envisage.tests'
19
20
class ClassLoadHookTestCase(unittest.TestCase):
    """ Tests for class load hooks. """

    def test_connect(self):
        """ A connected hook fires when the named class is created. """

        def on_class_loaded(cls):
            """ Called when a class is loaded. """

            on_class_loaded.cls = cls

        # Start from a known state so that a hook that never fires produces a
        # clean assertion failure instead of an AttributeError.
        on_class_loaded.cls = None

        # To register with 'MetaHasTraits' we use 'module_name.class_name'.
        hook = ClassLoadHook(
            class_name=ClassLoadHookTestCase.__module__ + '.Foo',
            on_load=on_class_loaded,
        )
        hook.connect()
        # The hook stays registered with 'MetaHasTraits' until disconnected;
        # remove it afterwards so it cannot leak into (and fire during) other
        # tests that create a class with the same name.
        self.addCleanup(hook.disconnect)

        class Foo(HasTraits):
            pass

        self.assertIs(Foo, on_class_loaded.cls)

    def test_class_already_loaded(self):
        """ Connecting a hook for an existing class fires it immediately. """

        def on_class_loaded(cls):
            """ Called when a class is loaded. """

            on_class_loaded.cls = cls

        # Start from a known state so that a hook that never fires produces a
        # clean assertion failure instead of an AttributeError.
        on_class_loaded.cls = None

        # To register with 'MetaHasTraits' we use 'module_name.class_name'.
        hook = ClassLoadHook(
            class_name=self._get_full_class_name(ClassLoadHookTestCase),
            on_load=on_class_loaded,
        )
        hook.connect()
        # Don't leave the hook registered after this test finishes.
        self.addCleanup(hook.disconnect)

        # Make sure the 'on_load' got called immediately because the class is
        # already loaded.
        self.assertIs(ClassLoadHookTestCase, on_class_loaded.cls)

    def test_disconnect(self):
        """ A disconnected hook no longer fires. """

        def on_class_loaded(cls):
            """ Called when a class is loaded. """

            on_class_loaded.cls = cls

        # Start from a known state so that a hook that never fires produces a
        # clean assertion failure instead of an AttributeError.
        on_class_loaded.cls = None

        # To register with 'MetaHasTraits' we use 'module_name.class_name'.
        hook = ClassLoadHook(
            class_name=ClassLoadHookTestCase.__module__ + '.Foo',
            on_load=on_class_loaded,
        )
        hook.connect()
        try:
            class Foo(HasTraits):
                pass

            self.assertIs(Foo, on_class_loaded.cls)

        finally:
            # Always disconnect, even if the assertion above fails, so the
            # hook cannot leak into other tests.
            hook.disconnect()

        # 'Reset' the listener.
        on_class_loaded.cls = None

        # Create a class with the same full name again; the disconnected hook
        # must not be called.
        class Foo(HasTraits):  # noqa: F811
            pass

        self.assertIsNone(on_class_loaded.cls)

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _get_full_class_name(self, cls):
        """ Return the 'module_name.class_name' name of a class.

        This is the form of name used when registering a hook with
        'MetaHasTraits'. Note that it is built from '__name__', not
        '__qualname__'.
        """

        return cls.__module__ + '.' + cls.__name__
102