# (C) Copyright 2007-2019 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only
# under the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
# Thanks for using Enthought open source!
""" An extension registry implementation with multiple providers. """


# Standard library imports.
import logging

# Enthought library imports.
from traits.api import List, provides, on_trait_change

# Local imports.
from .extension_registry import ExtensionRegistry
from .i_extension_provider import IExtensionProvider
from .i_provider_extension_registry import IProviderExtensionRegistry


# Logging.
logger = logging.getLogger(__name__)


@provides(IProviderExtensionRegistry)
class ProviderExtensionRegistry(ExtensionRegistry):
    """ An extension registry implementation with multiple providers.

    Extensions are fetched from the providers lazily, the first time an
    extension point is accessed, and cached per extension point as a list
    of lists: the inner list at position ``i`` holds (a copy of) the
    contributions of ``self._providers[i]``.
    """

    #### Protected 'ProviderExtensionRegistry' interface ######################

    # The extension providers that populate the registry.
    _providers = List(IExtensionProvider)

    ###########################################################################
    # 'IExtensionRegistry' interface.
    ###########################################################################

    def set_extensions(self, extension_point_id, extensions):
        """ Set the extensions to an extension point.

        Not supported by this registry: extensions can only be contributed
        by providers.

        Raises
        ------
        SystemError
            Always.
        """

        raise SystemError('extension points cannot be set')

    ###########################################################################
    # 'ProviderExtensionRegistry' interface.
    ###########################################################################

    def add_provider(self, provider):
        """ Add an extension provider.

        Listeners are notified for every already-accessed extension point to
        which the new provider contributes at least one extension.
        """

        events = self._add_provider(provider)

        for extension_point_id, (refs, added, index) in events.items():
            self._call_listeners(refs, extension_point_id, added, [], index)

    def get_providers(self):
        """ Return all of the providers in the registry.

        The result is a shallow copy, so mutating it does not affect the
        registry.
        """

        return self._providers[:]

    def remove_provider(self, provider):
        """ Remove an extension provider.

        Listeners are notified for every already-accessed extension point to
        which the provider had contributed at least one extension.

        Raise a 'ValueError' if the provider is not in the registry.

        """

        events = self._remove_provider(provider)

        for extension_point_id, (refs, removed, index) in events.items():
            self._call_listeners(refs, extension_point_id, [], removed, index)

    ###########################################################################
    # Protected 'ExtensionRegistry' interface.
    ###########################################################################

    def _get_extensions(self, extension_point_id):
        """ Return the extensions for the given extension point.

        The result is a new flat list containing every provider's
        contributions, in provider order. An unknown extension point logs a
        warning and yields an empty list.
        """

        # If we don't know about the extension point then it sure ain't got
        # any extensions!
        if extension_point_id not in self._extension_points:
            logger.warning(
                'getting extensions of unknown extension point <%s>',
                extension_point_id,
            )
            extensions = []

        # Has this extension point already been accessed?
        elif extension_point_id in self._extensions:
            extensions = self._extensions[extension_point_id]

        # If not, then ask each provider for its contributions to the
        # extension point and cache them.
        else:
            extensions = self._initialize_extensions(extension_point_id)
            self._extensions[extension_point_id] = extensions

        # The cache is a list of lists, one inner list per provider. Flatten
        # it into a brand new list so that callers modifying the result do
        # not modify the cache.
        return [
            extension
            for provider_extensions in extensions
            for extension in provider_extensions
        ]

    ###########################################################################
    # Protected 'ProviderExtensionRegistry' interface.
    ###########################################################################

    def _add_provider(self, provider):
        """ Add a new provider.

        Returns a dict mapping extension point ids to
        ``(listener_refs, added_extensions, index)`` tuples describing the
        change events that the caller should fire.
        """

        # Add the provider's extension points.
        self._add_provider_extension_points(provider)

        # Add the provider's extensions (this must happen before the provider
        # is appended, so that the cache and the provider list stay aligned).
        events = self._add_provider_extensions(provider)

        # And finally, tag it into the list of providers.
        self._providers.append(provider)

        return events

    def _add_provider_extensions(self, provider):
        """ Add a provider's extensions to the registry.

        Only extension points that have already been accessed (i.e. are
        cached) are touched; the others will pick the provider up lazily.
        """

        # Each provider can contribute to multiple extension points, so we
        # build up a dictionary of the 'ExtensionPointChanged' events that we
        # need to fire.
        events = {}

        # Does the provider contribute any extensions to an extension point
        # that has already been accessed?
        for extension_point_id, extensions in self._extensions.items():
            new = provider.get_extensions(extension_point_id)

            # We only need fire an event for this extension point if the
            # provider contributes any extensions. They are appended after
            # every existing provider, hence the index is the current total.
            if len(new) > 0:
                index = sum(map(len, extensions))
                refs = self._get_listener_refs(extension_point_id)
                events[extension_point_id] = (refs, new[:], index)

            # Cache a copy (as '_initialize_extensions' does) rather than the
            # provider's own list, so the cache cannot drift from what
            # listeners have been told.
            extensions.append(new[:])

        return events

    def _add_provider_extension_points(self, provider):
        """ Add a provider's extension points to the registry.

        An extension point with the same id as an existing one replaces it.
        """

        for extension_point in provider.get_extension_points():
            self._extension_points[extension_point.id] = extension_point

    def _remove_provider(self, provider):
        """ Remove a provider.

        Returns a dict mapping extension point ids to
        ``(listener_refs, removed_extensions, index)`` tuples describing the
        change events that the caller should fire.

        Raises 'ValueError' if the provider is not in the registry.
        """

        # Remove the provider's extensions (this also validates that the
        # provider is registered, via 'list.index').
        events = self._remove_provider_extensions(provider)

        # Remove the provider's extension points.
        self._remove_provider_extension_points(provider, events)

        # And finally take it out of the list of providers.
        self._providers.remove(provider)

        return events

    def _remove_provider_extensions(self, provider):
        """ Remove a provider's extensions from the registry.

        Raises 'ValueError' if the provider is not in the registry.
        """

        # Each provider can contribute to multiple extension points, so we
        # build up a dictionary of the 'ExtensionPointChanged' events that we
        # need to fire.
        events = {}

        # Find the index of the provider in the provider list. Its
        # contributions are at the same index in the extensions list of lists.
        index = self._providers.index(provider)

        # Does the provider contribute any extensions to an extension point
        # that has already been accessed?
        for extension_point_id, extensions in self._extensions.items():
            old = extensions[index]

            # We only need fire an event for this extension point if the
            # provider contributed any extensions. Their position in the flat
            # view is the total contributed by the providers before it.
            if len(old) > 0:
                offset = sum(map(len, extensions[:index]))
                refs = self._get_listener_refs(extension_point_id)
                events[extension_point_id] = (refs, old[:], offset)

            del extensions[index]

        return events

    def _remove_provider_extension_points(self, provider, events):
        """ Remove a provider's extension points from the registry.

        'events' is not used; it is accepted for signature compatibility
        with existing callers and overrides.
        """

        for extension_point in provider.get_extension_points():
            # Remove the extension point.
            del self._extension_points[extension_point.id]

    ###########################################################################
    # Private interface.
    ###########################################################################

    #### Trait change handlers ################################################

    @on_trait_change('_providers:extension_point_changed')
    def _providers_extension_point_changed(self, obj, trait_name, old, event):
        """ Dynamic trait change handler.

        Called when any provider reports that its contributions to an
        extension point changed. The cache is refreshed and the event is
        re-fired to the registry's listeners with its index translated into
        the flat, all-providers view.
        """

        logger.debug('provider <%s> extension point changed', obj)

        extension_point_id = event.extension_point_id

        # If the extension point has not yet been accessed then we don't fire a
        # changed event.
        #
        # This is because we only access extension points lazily and so we
        # can't tell what has actually changed because we have nothing to
        # compare it to!
        if extension_point_id not in self._extensions:
            return

        # This is a list of lists where each inner list contains the
        # contributions made to the extension point by a single provider. The
        # guard above guarantees that it exists.
        extensions = self._extensions[extension_point_id]

        # Find the index of the provider in the provider list. Its
        # contributions are at the same index in the extensions list of lists.
        provider_index = self._providers.index(obj)

        # Cache a copy of the updated list from the provider (consistent with
        # '_initialize_extensions').
        extensions[provider_index] = obj.get_extensions(extension_point_id)[:]

        # Find where the provider's contributions are in the whole 'list'.
        offset = sum(map(len, extensions[:provider_index]))

        # Translate the event index from one that refers to the list of
        # contributions from the provider, to the list of contributions from
        # all providers.
        index = self._translate_index(event.index, offset)

        # Find out who is listening.
        refs = self._get_listener_refs(extension_point_id)

        # Let any listeners know that the extensions have changed.
        self._call_listeners(
            refs, extension_point_id, event.added, event.removed, index
        )

    #### Methods ##############################################################

    def _initialize_extensions(self, extension_point_id):
        """ Initialize the extensions to an extension point.

        Returns a list of lists: one copy of each provider's contributions,
        in provider order.
        """

        # We store the extensions as a list of lists, with each inner list
        # containing the contributions from a single provider.
        extensions = [
            provider.get_extensions(extension_point_id)[:]
            for provider in self._providers
        ]

        logger.debug('extensions to <%s> <%s>', extension_point_id, extensions)

        return extensions

    def _translate_index(self, index, offset):
        """ Translate an event index by the given offset.

        'index' is either an int or a slice. For a slice, both bounds are
        shifted, so this assumes 'start' and 'stop' are not None.
        NOTE(review): confirm that providers never report slices with open
        bounds.
        """

        if isinstance(index, slice):
            return slice(index.start + offset, index.stop + offset, index.step)

        return index + offset

#### EOF ######################################################################