1# (C) Copyright 2007-2019 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only
6# under the conditions described in the aforementioned license.  The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8# Thanks for using Enthought open source!
9""" Tests for extension point bindings. """
10
11
12# Enthought library imports.
13from envisage.api import ExtensionPoint
14from envisage.api import bind_extension_point
15from traits.api import HasTraits, List
16from traits.testing.unittest_tools import unittest
17
18# Local imports.
19from envisage.tests.mutable_extension_registry import (
20    MutableExtensionRegistry
21)
22
23
def listener(obj, trait_name, old, new):
    """ Trait change handler that remembers the latest notification.

    Each call stores its four arguments as attributes of this function
    object (``obj``, ``trait_name``, ``old`` and ``new``), overwriting the
    values from any earlier call, so a test can inspect them after
    triggering a change.
    """

    vars(listener).update(obj=obj, trait_name=trait_name, old=old, new=new)
31
32
class ExtensionPointBindingTestCase(unittest.TestCase):
    """ Tests for extension point binding. """

    def setUp(self):
        """ Prepares the test fixture before each test method is called. """

        self.extension_registry = MutableExtensionRegistry()

        # 'ExtensionPoint.extension_registry' is class-level (i.e. global)
        # state that bindings fall back to when no registry is passed
        # explicitly. Remember what was there before this test so that it is
        # restored afterwards, instead of leaking this test's registry into
        # every test that runs later in the same process.
        class_namespace = vars(ExtensionPoint)
        if "extension_registry" in class_namespace:
            self.addCleanup(
                setattr,
                ExtensionPoint,
                "extension_registry",
                class_namespace["extension_registry"],
            )
        else:
            self.addCleanup(delattr, ExtensionPoint, "extension_registry")

        # Use the extension registry for all extension points and bindings.
        ExtensionPoint.extension_registry = self.extension_registry

    def test_untyped_extension_point(self):
        """ untyped extension point """

        registry = self.extension_registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Add an extension.
        registry.add_extension('my.ep', 42)

        # Declare a class that consumes the extension.
        class Foo(HasTraits):
            x = List

        f = Foo()
        f.on_trait_change(listener)

        # Make some bindings.
        bind_extension_point(f, 'x', 'my.ep')

        # Make sure that the object was initialized properly.
        self.assertEqual(1, len(f.x))
        self.assertEqual(42, f.x[0])

        # Add another extension.
        registry.add_extension('my.ep', 'a string')

        # Make sure that the object picked up the new extension...
        self.assertEqual(2, len(f.x))
        self.assertIn(42, f.x)
        self.assertIn('a string', f.x)

        # ... and that the correct trait change event was fired.
        self.assertEqual(f, listener.obj)
        self.assertEqual('x_items', listener.trait_name)
        self.assertEqual(1, len(listener.new.added))
        self.assertIn('a string', listener.new.added)

    def test_set_extensions_via_trait(self):
        """ set extensions via trait """

        registry = self.extension_registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Add an extension.
        registry.add_extension('my.ep', 42)

        # Declare a class that consumes the extension.
        class Foo(HasTraits):
            x = List

        f = Foo()
        f.on_trait_change(listener)

        # Make some bindings.
        bind_extension_point(f, 'x', 'my.ep')

        # Make sure that the object was initialized properly.
        self.assertEqual(1, len(f.x))
        self.assertEqual(42, f.x[0])

        # Set the extensions.
        f.x = ['a string']

        # Make sure that the object picked up the new extension...
        self.assertEqual(1, len(f.x))
        self.assertIn('a string', f.x)

        # ... that the assignment was propagated back to the registry...
        self.assertEqual(1, len(registry.get_extensions('my.ep')))
        self.assertIn('a string', registry.get_extensions('my.ep'))

        # ... and that the correct trait change event was fired.
        self.assertEqual(f, listener.obj)
        self.assertEqual('x', listener.trait_name)
        self.assertEqual(1, len(listener.new))
        self.assertIn('a string', listener.new)

    def test_set_extensions_via_registry(self):
        """ set extensions via registry """

        registry = self.extension_registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Add an extension.
        registry.add_extension('my.ep', 42)

        # Declare a class that consumes the extension.
        class Foo(HasTraits):
            x = List

        f = Foo()
        f.on_trait_change(listener)

        # Make some bindings.
        bind_extension_point(f, 'x', 'my.ep')

        # Make sure that the object was initialized properly.
        self.assertEqual(1, len(f.x))
        self.assertEqual(42, f.x[0])

        # Set the extensions.
        registry.set_extensions('my.ep', ['a string'])

        # Make sure that the object picked up the new extension...
        self.assertEqual(1, len(f.x))
        self.assertIn('a string', f.x)

        # ... and that the correct trait change event was fired.
        self.assertEqual(f, listener.obj)
        self.assertEqual('x', listener.trait_name)
        self.assertEqual(1, len(listener.new))
        self.assertIn('a string', listener.new)

    def test_explicit_extension_registry(self):
        """ explicit extension registry """

        registry = self.extension_registry

        # Add an extension point.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Add an extension.
        registry.add_extension('my.ep', 42)

        # Declare a class that consumes the extension.
        class Foo(HasTraits):
            x = List

        f = Foo()
        f.on_trait_change(listener)

        # Create an empty extension registry and use that in the binding.
        extension_registry = MutableExtensionRegistry()

        # Make some bindings.
        bind_extension_point(f, 'x', 'my.ep', extension_registry)

        # Make sure that we pick up the empty extension registry and not the
        # default one.
        self.assertEqual(0, len(f.x))

    def test_should_be_able_to_bind_multiple_traits_on_a_single_object(self):
        """ multiple traits on one object can each be bound """

        registry = self.extension_registry

        # Add 2 extension points.
        registry.add_extension_point(self._create_extension_point('my.ep'))
        registry.add_extension_point(self._create_extension_point('another.ep'))

        # Declare a class that consumes both of the extension points.
        class Foo(HasTraits):
            x = List
            y = List

        f = Foo()

        # Bind two different traits on the object to the extension points.
        bind_extension_point(f, 'x', 'my.ep', registry)
        bind_extension_point(f, 'y', 'another.ep', registry)
        self.assertEqual(0, len(f.x))
        self.assertEqual(0, len(f.y))

        # Add some contributions to the extension points.
        registry.add_extension('my.ep', 42)
        registry.add_extensions('another.ep', [98, 99, 100])

        # Make sure both traits were bound correctly.
        self.assertEqual(1, len(f.x))
        self.assertEqual(3, len(f.y))

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _create_extension_point(self, id, trait_type=List, desc=''):
        """ Create an extension point.

        Parameters
        ----------
        id : str
            The extension point identifier (e.g. 'my.ep').
        trait_type
            The trait type passed through to ``ExtensionPoint``.
        desc : str
            The description passed through to ``ExtensionPoint``.

        Returns
        -------
        ExtensionPoint
            A new extension point; it is not registered anywhere yet.
        """

        return ExtensionPoint(id=id, trait_type=trait_type, desc=desc)
227