1# (C) Copyright 2007-2019 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only
6# under the conditions described in the aforementioned license.  The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8# Thanks for using Enthought open source!
9""" Tests for the base extension registry. """
10
11
12# Enthought library imports.
13from envisage.api import Application, ExtensionPoint
14from envisage.api import ExtensionRegistry, UnknownExtensionPoint
15from traits.api import List
16from traits.testing.unittest_tools import unittest
17
18
class ExtensionRegistryTestCase(unittest.TestCase):
    """ Tests for the base extension registry.

    Every test talks to the registry through an ``Application`` that wraps
    an ``ExtensionRegistry``, so the same assertions also check that the
    application exposes the registry interface.

    """

    def setUp(self):
        """ Prepares the test fixture before each test method is called. """

        # We do all of the testing via the application to make sure it offers
        # the same interface!
        self.registry = Application(extension_registry=ExtensionRegistry())

    def test_empty_registry(self):
        """ A fresh registry has no extensions and no extension points. """

        registry = self.registry

        # Asking for the extensions of an unregistered extension point is
        # expected to give back an empty collection.
        extensions = registry.get_extensions('my.ep')
        self.assertEqual(0, len(extensions))

        # Make sure there are no extension points.
        extension_points = registry.get_extension_points()
        self.assertEqual(0, len(extension_points))

    def test_add_extension_point(self):
        """ Adding an extension point registers it with no extensions. """

        registry = self.registry

        # Add an extension *point*.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Adding the point must not contribute any extensions to it.
        extensions = registry.get_extensions('my.ep')
        self.assertEqual(0, len(extensions))

        # Make sure there's one and only one extension point.
        extension_points = registry.get_extension_points()
        self.assertEqual(1, len(extension_points))
        self.assertEqual('my.ep', extension_points[0].id)

    def test_get_extension_point(self):
        """ An added extension point can be looked up by its Id. """

        registry = self.registry

        # Add an extension *point*.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Make sure we can get it. The identity check against None avoids
        # going through the object's own equality operators.
        extension_point = registry.get_extension_point('my.ep')
        self.assertIsNotNone(extension_point)
        self.assertEqual('my.ep', extension_point.id)

    def test_remove_empty_extension_point(self):
        """ Removing an extension point that has no extensions. """

        registry = self.registry

        # Add an extension point...
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # ...and remove it!
        registry.remove_extension_point('my.ep')

        # Make sure there are no extension points.
        extension_points = registry.get_extension_points()
        self.assertEqual(0, len(extension_points))

    def test_remove_non_empty_extension_point(self):
        """ Removing an extension point also discards its extensions. """

        registry = self.registry

        # Add an extension point...
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # ... with some extensions...
        registry.set_extensions('my.ep', [42])

        # ...and remove it!
        registry.remove_extension_point('my.ep')

        # Make sure there are no extension points.
        extension_points = registry.get_extension_points()
        self.assertEqual(0, len(extension_points))

        # And that the extensions are gone too.
        self.assertEqual([], registry.get_extensions('my.ep'))

    def test_remove_non_existent_extension_point(self):
        """ Removing an unknown extension point raises an error. """

        registry = self.registry

        with self.assertRaises(UnknownExtensionPoint):
            registry.remove_extension_point("my.ep")

    def test_remove_non_existent_listener(self):
        """ Removing a listener that was never added raises an error. """

        registry = self.registry

        def listener(registry, extension_point, added, removed, index):
            """ Called when an extension point has changed. """

            # The listener is never registered in this test, so this line is
            # not expected to run.
            self.listener_called = (registry, extension_point, added, removed)

        with self.assertRaises(ValueError):
            registry.remove_extension_point_listener(listener)

    def test_set_extensions(self):
        """ Extensions set on an extension point can be read back. """

        registry = self.registry

        # Add an extension *point*.
        registry.add_extension_point(self._create_extension_point('my.ep'))

        # Set some extensions.
        registry.set_extensions('my.ep', [1, 2, 3])

        # Make sure we can get them.
        self.assertEqual([1, 2, 3], registry.get_extensions('my.ep'))

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _create_extension_point(self, id, trait_type=List, desc=''):
        """ Create an extension point.

        The arguments are passed straight through to ``ExtensionPoint`` as
        the keyword arguments of the same names.

        """

        return ExtensionPoint(id=id, trait_type=trait_type, desc=desc)