# (C) Copyright 2007-2019 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only
# under the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
# Thanks for using Enthought open source!
""" Tests for class load hooks. """


from envisage.api import ClassLoadHook
from traits.api import HasTraits
from traits.testing.unittest_tools import unittest


# This module's package.
PKG = 'envisage.tests'


class ClassLoadHookTestCase(unittest.TestCase):
    """ Tests for class load hooks. """

    def test_connect(self):
        """ connect: 'on_load' is called when the named class is defined. """

        def on_class_loaded(cls):
            """ Called when a class is loaded. """

            on_class_loaded.cls = cls

        # Start from a known state so that, if the hook never fires, the
        # assertion below fails cleanly rather than raising AttributeError.
        on_class_loaded.cls = None

        # To register with 'MetaHasTraits' we use 'module_name.class_name'.
        hook = ClassLoadHook(
            class_name=ClassLoadHookTestCase.__module__ + '.Foo',
            on_load=on_class_loaded,
        )
        hook.connect()
        # Make sure the hook does not outlive this test.
        self.addCleanup(hook.disconnect)

        class Foo(HasTraits):
            pass

        self.assertIs(Foo, on_class_loaded.cls)

    def test_class_already_loaded(self):
        """ class already loaded: 'on_load' is called during 'connect'. """

        def on_class_loaded(cls):
            """ Called when a class is loaded. """

            on_class_loaded.cls = cls

        # Start from a known state so that, if the hook never fires, the
        # assertion below fails cleanly rather than raising AttributeError.
        on_class_loaded.cls = None

        # To register with 'MetaHasTraits' we use 'module_name.class_name'.
        hook = ClassLoadHook(
            class_name=self._get_full_class_name(ClassLoadHookTestCase),
            on_load=on_class_loaded,
        )
        hook.connect()
        # Make sure the hook does not outlive this test.
        self.addCleanup(hook.disconnect)

        # Make sure the 'on_load' got called immediately because the class is
        # already loaded.
        self.assertIs(ClassLoadHookTestCase, on_class_loaded.cls)

    def test_disconnect(self):
        """ disconnect: 'on_load' is no longer called after disconnecting. """

        def on_class_loaded(cls):
            """ Called when a class is loaded. """

            on_class_loaded.cls = cls

        # Start from a known state so that, if the hook never fires, the
        # first assertion fails cleanly rather than raising AttributeError.
        on_class_loaded.cls = None

        # To register with 'MetaHasTraits' we use 'module_name.class_name'.
        hook = ClassLoadHook(
            class_name=ClassLoadHookTestCase.__module__ + '.Foo',
            on_load=on_class_loaded,
        )
        hook.connect()

        class Foo(HasTraits):
            pass

        self.assertIs(Foo, on_class_loaded.cls)

        # 'Reset' the listener.
        on_class_loaded.cls = None

        # Now disconnect (done explicitly here, so no cleanup is registered;
        # disconnecting twice is not part of what this test exercises).
        hook.disconnect()

        # Deliberately define a class with the same name again: the hook was
        # registered for '<module>.Foo', so this would fire it if it were
        # still connected.
        class Foo(HasTraits):  # noqa: F811
            pass

        self.assertIsNone(on_class_loaded.cls)

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _get_full_class_name(self, cls):
        """ Return the full (possibly) dotted name of a class. """

        return cls.__module__ + '.' + cls.__name__