1# (C) Copyright 2007-2019 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only
6# under the conditions described in the aforementioned license.  The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8# Thanks for using Enthought open source!
9""" Tests for extension points. """
10
11
12# Enthought library imports.
13from envisage.api import Application, ExtensionPoint
14from envisage.api import ExtensionRegistry
15from traits.api import HasTraits, Int, List, TraitError
16from traits.testing.unittest_tools import unittest
17
18
class TestBase(HasTraits):
    """ Common base for the test classes that declare 'ExtensionPoint' traits.

    Subclasses find their extension registry through the class attribute
    below, which the test case's 'setUp' points at the registry under test.
    """

    # Unset until a test fixture assigns a registry to it.
    extension_registry = None
23
24
class ExtensionPointTestCase(unittest.TestCase):
    """ Tests for extension points.

    All registry operations go through an 'Application' wrapping an
    'ExtensionRegistry' (see 'setUp'). This also checks that the application
    offers the same extension-registry interface.
    """

    def setUp(self):
        """ Prepares the test fixture before each test method is called. """

        # We do all of the testing via the application to make sure it offers
        # the same interface!
        self.registry = Application(extension_registry=ExtensionRegistry())

        # Set the extension registry used by the test classes.
        TestBase.extension_registry = self.registry

    def tearDown(self):
        """ Undoes the class-level state set up by 'setUp'. """

        # 'extension_registry' is a class attribute of 'TestBase', so without
        # this reset the application created in 'setUp' would stay attached
        # to 'TestBase' (and be kept alive) after the test has finished.
        TestBase.extension_registry = None

    def test_invalid_extension_point_type(self):
        """ invalid extension point type """

        # Extension points currently have to be 'List's of something! The
        # default is a list of anything.
        with self.assertRaises(TypeError):
            ExtensionPoint(Int, "my.ep")

    def test_no_reference_to_extension_registry(self):
        """ no reference to extension registry """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Set the extensions.
        registry.set_extensions('my.ep', 'xxx')

        # Declare a class that consumes the extension. It deliberately derives
        # from 'HasTraits' rather than 'TestBase', so it has no registry.
        class Foo(HasTraits):
            x = ExtensionPoint(List(Int), id='my.ep')

        # We should get an exception because the object does not have an
        # 'extension_registry' trait.
        f = Foo()
        with self.assertRaises(ValueError):
            getattr(f, "x")

    def test_extension_point_changed(self):
        """ extension point changed """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(id='my.ep')

            # Defined up front so that, if the change handler never fires,
            # the assertions below fail cleanly instead of raising
            # AttributeError.
            x_changed_called = False

            def _x_changed(self):
                """ Static trait change handler. """

                self.x_changed_called = True

        f = Foo()

        # Connect the extension points on the object so that it can listen
        # for changes.
        ExtensionPoint.connect_extension_point_traits(f)

        # Set the extensions.
        registry.set_extensions('my.ep', [42, 'a string', True])

        # Make sure that instances of the class pick up the extensions.
        self.assertEqual(3, len(f.x))
        self.assertEqual([42, 'a string', True], f.x)

        # Make sure the trait change handler was called.
        self.assertTrue(f.x_changed_called)

        # Reset the change handler flag.
        f.x_changed_called = False

        # Disconnect the extension points on the object.
        ExtensionPoint.disconnect_extension_point_traits(f)

        # Set the extensions.
        registry.set_extensions('my.ep', [98, 99, 100])

        # Make sure the trait change handler was *not* called.
        self.assertFalse(f.x_changed_called)

    def test_untyped_extension_point(self):
        """ untyped extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Set the extensions.
        registry.set_extensions('my.ep', [42, 'a string', True])

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(id='my.ep')

        # Make sure that instances of the class pick up the extensions.
        f = Foo()
        self.assertEqual(3, len(f.x))
        self.assertEqual([42, 'a string', True], f.x)

        g = Foo()
        self.assertEqual(3, len(g.x))
        self.assertEqual([42, 'a string', True], g.x)

    def test_typed_extension_point(self):
        """ typed extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Set the extensions.
        registry.set_extensions('my.ep', [42, 43, 44])

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(List(Int), id='my.ep')

        # Make sure that instances of the class pick up the extensions.
        f = Foo()
        self.assertEqual(3, len(f.x))
        self.assertEqual([42, 43, 44], f.x)

        g = Foo()
        self.assertEqual(3, len(g.x))
        self.assertEqual([42, 43, 44], g.x)

    def test_invalid_extension_point(self):
        """ invalid extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Set the extensions.
        registry.set_extensions('my.ep', 'xxx')

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(List(Int), id='my.ep')

        # Make sure we get a trait error because the type of the extension
        # doesn't match that of the extension point.
        f = Foo()
        with self.assertRaises(TraitError):
            getattr(f, "x")

    def test_extension_point_with_no_id(self):
        """ extension point with no Id """

        # Declaring the class is what should fail, so it is wrapped in a
        # function to run it inside the 'assertRaises' context.
        def factory():
            class Foo(TestBase):
                x = ExtensionPoint(List(Int))

        with self.assertRaises(ValueError):
            factory()

    def test_set_untyped_extension_point(self):
        """ set untyped extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(id='my.ep')

        # Make sure that when we set the trait the extension registry gets
        # updated.
        f = Foo()
        f.x = [42]

        self.assertEqual([42], registry.get_extensions('my.ep'))

    def test_set_typed_extension_point(self):
        """ set typed extension point """

        registry = self.registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Declare a class that consumes the extension.
        class Foo(TestBase):
            x = ExtensionPoint(List(Int), id='my.ep')

        # Make sure that when we set the trait the extension registry gets
        # updated.
        f = Foo()
        f.x = [42]

        self.assertEqual([42], registry.get_extensions('my.ep'))

    def test_extension_point_str_representation(self):
        """ test the string representation of the extension point """
        ep_repr = "ExtensionPoint(id={})"
        ep = self._create_extension_point('my.ep')
        self.assertEqual(ep_repr.format('my.ep'), str(ep))
        self.assertEqual(ep_repr.format('my.ep'), repr(ep))

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _create_extension_point(self, id, trait_type=List, desc=''):
        """ Create an extension point.

        Parameters
        ----------
        id : str
            The id of the extension point, e.g. 'my.ep'.
        trait_type : optional
            Passed through as the 'trait_type' of the extension point;
            defaults to 'List'.
        desc : str, optional
            Passed through as the 'desc' of the extension point; defaults
            to the empty string.

        Returns
        -------
        ExtensionPoint
            A new extension point built from the given arguments.
        """

        return ExtensionPoint(id=id, trait_type=trait_type, desc=desc)
244