# (C) Copyright 2007-2019 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only
# under the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
# Thanks for using Enthought open source!
""" Tests for extension points. """

# Standard library imports.
import unittest

# Enthought library imports.
from envisage.api import Application, ExtensionPoint
from envisage.api import ExtensionRegistry
from traits.api import HasTraits, Int, List, TraitError


class TestBase(HasTraits):
    """ Base class for all test classes that use the 'ExtensionPoint' type. """

    # The registry consulted by 'ExtensionPoint' traits declared on
    # subclasses. ExtensionPointTestCase assigns it (on the class, so every
    # subclass sees it) in setUp and resets it to None in tearDown.
    extension_registry = None


class ExtensionPointTestCase(unittest.TestCase):
    """ Tests for extension points. """

    def setUp(self):
        """ Prepares the test fixture before each test method is called. """

        # We do all of the testing via the application to make sure it offers
        # the same interface!
        self.registry = Application(extension_registry=ExtensionRegistry())

        # Set the extension registry used by the test classes.
        TestBase.extension_registry = self.registry

    def tearDown(self):
        """ Undoes the class-level state set up by 'setUp'. """

        # 'setUp' stores the registry on the *class*, so without this reset
        # it would outlive the test and keep the application referenced.
        TestBase.extension_registry = None

    def test_invalid_extension_point_type(self):
        """ invalid extension point type """

        # Extension points currently have to be 'List's of something! The
        # default is a list of anything.
        with self.assertRaises(TypeError):
            ExtensionPoint(Int, "my.ep")

    def test_no_reference_to_extension_registry(self):
        """ no reference to extension registry """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Set the extensions.
        registry.set_extensions('my.ep', 'xxx')

        # Declare a class that consumes the extension.
        class Foo(HasTraits):
            x = ExtensionPoint(List(Int), id='my.ep')

        # We should get an exception because the object does not have an
        # 'extension_registry' trait.
        f = Foo()
        with self.assertRaises(ValueError):
            getattr(f, "x")

    def test_extension_point_changed(self):
        """ extension point changed """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(id='my.ep')

            def _x_changed(self):
                """ Static trait change handler. """

                self.x_changed_called = True

        f = Foo()

        # Start from a known state so that a handler that never fires shows
        # up as an assertion failure rather than an AttributeError.
        f.x_changed_called = False

        # Connect the extension points on the object so that it can listen
        # for changes.
        ExtensionPoint.connect_extension_point_traits(f)

        # Set the extensions.
        registry.set_extensions('my.ep', [42, 'a string', True])

        # Make sure that instances of the class pick up the extensions.
        self.assertEqual(3, len(f.x))
        self.assertEqual([42, 'a string', True], f.x)

        # Make sure the trait change handler was called.
        self.assertTrue(f.x_changed_called)

        # Reset the change handler flag.
        f.x_changed_called = False

        # Disconnect the extension points on the object.
        ExtensionPoint.disconnect_extension_point_traits(f)

        # Set the extensions.
        registry.set_extensions('my.ep', [98, 99, 100])

        # Make sure the trait change handler was *not* called.
        self.assertFalse(f.x_changed_called)

    def test_untyped_extension_point(self):
        """ untyped extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Set the extensions.
        registry.set_extensions('my.ep', [42, 'a string', True])

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(id='my.ep')

        # Make sure that instances of the class pick up the extensions.
        f = Foo()
        self.assertEqual(3, len(f.x))
        self.assertEqual([42, 'a string', True], f.x)

        g = Foo()
        self.assertEqual(3, len(g.x))
        self.assertEqual([42, 'a string', True], g.x)

    def test_typed_extension_point(self):
        """ typed extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Set the extensions.
        registry.set_extensions('my.ep', [42, 43, 44])

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(List(Int), id='my.ep')

        # Make sure that instances of the class pick up the extensions.
        f = Foo()
        self.assertEqual(3, len(f.x))
        self.assertEqual([42, 43, 44], f.x)

        g = Foo()
        self.assertEqual(3, len(g.x))
        self.assertEqual([42, 43, 44], g.x)

    def test_invalid_extension_point(self):
        """ invalid extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Set the extensions.
        registry.set_extensions('my.ep', 'xxx')

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(List(Int), id='my.ep')

        # Make sure we get a trait error because the type of the extension
        # doesn't match that of the extension point.
        f = Foo()
        with self.assertRaises(TraitError):
            getattr(f, "x")

    def test_extension_point_with_no_id(self):
        """ extension point with no Id """

        def factory():
            class Foo(TestBase):
                x = ExtensionPoint(List(Int))

        # Declaring the class itself (not instantiating it) must fail.
        with self.assertRaises(ValueError):
            factory()

    def test_set_untyped_extension_point(self):
        """ set untyped extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(id='my.ep')

        # Make sure that when we set the trait the extension registry gets
        # updated.
        f = Foo()
        f.x = [42]

        self.assertEqual([42], registry.get_extensions('my.ep'))

    def test_set_typed_extension_point(self):
        """ set typed extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(List(Int), id='my.ep')

        # Make sure that when we set the trait the extension registry gets
        # updated.
        f = Foo()
        f.x = [42]

        self.assertEqual([42], registry.get_extensions('my.ep'))

    def test_extension_point_str_representation(self):
        """ test the string representation of the extension point """
        ep_repr = "ExtensionPoint(id={})"
        ep = self._create_extension_point('my.ep')
        self.assertEqual(ep_repr.format('my.ep'), str(ep))
        self.assertEqual(ep_repr.format('my.ep'), repr(ep))

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _create_extension_point(self, id, trait_type=List, desc=''):
        """ Create an extension point.

        Parameters
        ----------
        id : str
            The extension point id (name kept as 'id' for compatibility,
            even though it shadows the builtin).
        trait_type : trait type
            The trait type of the extension point; defaults to 'List'.
        desc : str
            A description of the extension point.

        Returns
        -------
        ExtensionPoint
            A new, unregistered extension point.
        """

        return ExtensionPoint(id=id, trait_type=trait_type, desc=desc)