1# (C) Copyright 2007-2019 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only
6# under the conditions described in the aforementioned license.  The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8# Thanks for using Enthought open source!
9""" An extension registry implementation with multiple providers. """
10
11
12# Standard library imports.
13import logging
14
15# Enthought library imports.
16from traits.api import List, provides, on_trait_change
17
18# Local imports.
19from .extension_registry import ExtensionRegistry
20from .i_extension_provider import IExtensionProvider
21from .i_provider_extension_registry import IProviderExtensionRegistry
22
23
24# Logging.
25logger = logging.getLogger(__name__)
26
27
@provides(IProviderExtensionRegistry)
class ProviderExtensionRegistry(ExtensionRegistry):
    """ An extension registry implementation with multiple providers.

    Extension points and extensions are contributed by extension providers
    (see 'add_provider' and 'remove_provider'); they cannot be set directly.

    Extensions are gathered lazily. The first time an extension point is
    accessed, every provider is asked for its contributions and the result is
    cached in '_extensions' as a list of lists. The i-th inner list holds the
    contributions of the i-th provider in '_providers'.

    """

    #### Protected 'ProviderExtensionRegistry' interface ######################

    # The extension providers that populate the registry. Order matters: a
    # provider's position in this list is also the position of its
    # contributions in every cached list of lists in '_extensions'.
    _providers = List(IExtensionProvider)

    ###########################################################################
    # 'IExtensionRegistry' interface.
    ###########################################################################

    def set_extensions(self, extension_point_id, extensions):
        """ Set the extensions to an extension point.

        Not supported by this registry: extensions can only be contributed
        by providers. Always raises a 'SystemError'.

        """

        raise SystemError('extension points cannot be set')

    ###########################################################################
    # 'ProviderExtensionRegistry' interface.
    ###########################################################################

    def add_provider(self, provider):
        """ Add an extension provider.

        Listeners of already-accessed extension points that the provider
        contributes to are notified of the added extensions.

        """

        events = self._add_provider(provider)

        for extension_point_id, (refs, added, index) in events.items():
            self._call_listeners(refs, extension_point_id, added, [], index)

    def get_providers(self):
        """ Return a copy of the list of providers in the registry. """

        return self._providers[:]

    def remove_provider(self, provider):
        """ Remove an extension provider.

        Listeners of already-accessed extension points that the provider
        contributed to are notified of the removed extensions.

        Raise a 'ValueError' if the provider is not in the registry.

        """

        events = self._remove_provider(provider)

        for extension_point_id, (refs, removed, index) in events.items():
            self._call_listeners(refs, extension_point_id, [], removed, index)

    ###########################################################################
    # Protected 'ExtensionRegistry' interface.
    ###########################################################################

    def _get_extensions(self, extension_point_id):
        """ Return the extensions for the given extension point.

        The result is a new flat list that concatenates the contributions of
        every provider, in provider order. An unknown extension point logs a
        warning and yields an empty list.

        """

        if extension_point_id not in self._extension_points:
            # We don't know about the extension point, so it can't have any
            # extensions. The id is passed as a logging argument so that
            # formatting is deferred to the logging framework (and cannot fail
            # for ids that '%' would treat as multiple arguments).
            logger.warning(
                'getting extensions of unknown extension point <%s>',
                extension_point_id,
            )
            extensions = []

        elif extension_point_id in self._extensions:
            # Already accessed: reuse the cached per-provider lists.
            extensions = self._extensions[extension_point_id]

        else:
            # First access: ask each provider for its contributions and cache
            # the result.
            extensions = self._initialize_extensions(extension_point_id)
            self._extensions[extension_point_id] = extensions

        # Flatten the per-provider lists into a single, freshly built list, so
        # callers cannot mutate the cache through the returned value.
        all_extensions = []
        for provider_extensions in extensions:
            all_extensions.extend(provider_extensions)

        return all_extensions

    ###########################################################################
    # Protected 'ProviderExtensionRegistry' interface.
    ###########################################################################

    def _add_provider(self, provider):
        """ Add a new provider.

        Return the events dictionary built by '_add_provider_extensions'.

        """

        # Add the provider's extension points.
        self._add_provider_extension_points(provider)

        # Add the provider's extensions. This must happen before the provider
        # is appended so the computed event indices cover only the existing
        # providers' contributions.
        events = self._add_provider_extensions(provider)

        # And finally, tag it onto the end of the list of providers.
        self._providers.append(provider)

        return events

    def _add_provider_extensions(self, provider):
        """ Add a provider's extensions to the registry.

        Only extension points that have already been accessed (i.e. that are
        cached in '_extensions') are updated here. Others pick up the
        provider's contributions lazily on first access.

        Return a dictionary mapping extension point ids to
        '(listener_refs, added_extensions, index)' tuples describing the
        'ExtensionPointChanged' events that need to be fired.

        """

        events = {}

        for extension_point_id, extensions in self._extensions.items():
            # Store a copy, as '_initialize_extensions' does, so the cache is
            # never aliased to a list owned by the provider.
            new = provider.get_extensions(extension_point_id)[:]

            # We only need to fire an event if the provider contributes
            # anything to this extension point.
            if new:
                # The new provider goes at the end, so its contributions start
                # right after everything contributed so far.
                index = sum(map(len, extensions))
                refs = self._get_listener_refs(extension_point_id)
                events[extension_point_id] = (refs, new[:], index)

            extensions.append(new)

        return events

    def _add_provider_extension_points(self, provider):
        """ Add a provider's extension points to the registry.

        An extension point whose id is already registered is replaced.

        """

        for extension_point in provider.get_extension_points():
            self._extension_points[extension_point.id] = extension_point

    def _remove_provider(self, provider):
        """ Remove a provider.

        Raise a 'ValueError' if the provider is not in the registry. Return
        the events dictionary built by '_remove_provider_extensions'.

        """

        # Remove the provider's extensions. This raises 'ValueError' for an
        # unknown provider before anything has been modified.
        events = self._remove_provider_extensions(provider)

        # Remove the provider's extension points.
        # NOTE(review): a KeyError here (e.g. an extension point id already
        # removed via another provider) would leave the provider in
        # '_providers' with its cached extensions gone -- confirm intended.
        self._remove_provider_extension_points(provider, events)

        # And finally take it out of the list of providers.
        self._providers.remove(provider)

        return events

    def _remove_provider_extensions(self, provider):
        """ Remove a provider's extensions from the registry.

        Raise a 'ValueError' if the provider is not in the registry.

        Return a dictionary mapping extension point ids to
        '(listener_refs, removed_extensions, index)' tuples describing the
        'ExtensionPointChanged' events that need to be fired.

        """

        events = {}

        # The provider's contributions are at the same index in each cached
        # list of lists as the provider is in the provider list.
        index = self._providers.index(provider)

        # Only extension points that have already been accessed are cached.
        for extension_point_id, extensions in self._extensions.items():
            old = extensions[index]

            # We only need to fire an event if the provider contributed
            # anything to this extension point.
            if old:
                # Position of the provider's first contribution in the
                # flattened list of all contributions.
                offset = sum(map(len, extensions[:index]))
                refs = self._get_listener_refs(extension_point_id)
                events[extension_point_id] = (refs, old[:], offset)

            del extensions[index]

        return events

    def _remove_provider_extension_points(self, provider, events):
        """ Remove a provider's extension points from the registry.

        Raise a 'KeyError' if one of the provider's extension points is not
        registered. The 'events' argument is not used here; it is passed by
        '_remove_provider' for the benefit of overriding implementations.

        """

        for extension_point in provider.get_extension_points():
            del self._extension_points[extension_point.id]

    ###########################################################################
    # Private interface.
    ###########################################################################

    #### Trait change handlers ################################################

    @on_trait_change('_providers:extension_point_changed')
    def _providers_extension_point_changed(self, obj, trait_name, old, event):
        """ Dynamic trait change handler.

        Called when any provider fires 'extension_point_changed'. The cached
        contributions of that provider are refreshed, and the event's index
        is translated from the provider's own list to the flattened list of
        all providers' contributions before listeners are called.

        """

        logger.debug('provider <%s> extension point changed', obj)

        extension_point_id = event.extension_point_id

        # If the extension point has not yet been accessed then we don't fire a
        # changed event. Extension points are only populated lazily, so there
        # is nothing to compare the change against; the provider's current
        # contributions will be picked up on first access.
        if extension_point_id not in self._extensions:
            return

        # A list of lists where each inner list contains the contributions
        # made to the extension point by a single provider.
        extensions = self._extensions[extension_point_id]

        # The provider's contributions are at the same index in the list of
        # lists as the provider is in the provider list.
        provider_index = self._providers.index(obj)

        # Refresh the cache with a copy of the provider's updated list (a copy,
        # so the cache is never aliased to a list owned by the provider).
        extensions[provider_index] = obj.get_extensions(extension_point_id)[:]

        # Find where the provider's contributions start in the whole 'list'.
        offset = sum(map(len, extensions[:provider_index]))

        # Translate the event index from one that refers to the provider's own
        # contributions to one that refers to all providers' contributions.
        index = self._translate_index(event.index, offset)

        # Find out who is listening and let them know what changed.
        refs = self._get_listener_refs(extension_point_id)
        self._call_listeners(
            refs, extension_point_id, event.added, event.removed, index
        )

    #### Methods ##############################################################

    def _initialize_extensions(self, extension_point_id):
        """ Initialize the extensions to an extension point.

        Return a list of lists, with each inner list being a copy of the
        contributions from a single provider, in provider order.

        """

        extensions = [
            provider.get_extensions(extension_point_id)[:]
            for provider in self._providers
        ]

        logger.debug('extensions to <%s> <%s>', extension_point_id, extensions)

        return extensions

    def _translate_index(self, index, offset):
        """ Translate an event index by the given offset.

        'index' is either an int or a slice. For a slice, both its start and
        stop are shifted by 'offset' and its step is kept.

        """

        if isinstance(index, slice):
            # NOTE(review): assumes start/stop are not None -- confirm that
            # provider events always carry explicit slice bounds.
            return slice(index.start + offset, index.stop + offset, index.step)

        return index + offset
304
305#### EOF ######################################################################
306