"""
Extension Mechanism for nodes.

The extension mechanism makes it possible to dynamically add class attributes,
especially methods, for specific features to node classes
(e.g. nodes need a _fork and _join method for parallelization).
It is also possible for users to define new extensions to provide new
functionality for MDP nodes without having to modify any MDP code.

Without the extension mechanism extending nodes would be done by inheritance,
which is fine unless one wants to use multiple inheritance at the same time
(requiring multiple inheritance for every combination of extensions one wants
to use). The extension mechanism does not depend on inheritance, instead it
adds the methods to the node classes dynamically at runtime. This makes it
possible to activate extensions just when they are needed, reducing the risk
of interference between different extensions.

However, since the extension mechanism provides a special Metaclass it is
still possible to define the extension nodes as classes derived from nodes.
This keeps the code readable and is compatible with automatic code checkers
(like the background pylint checks in the Eclipse IDE with PyDev).
"""
from __future__ import print_function
from builtins import str
from builtins import object

from mdp import MDPException, NodeMetaclass
from future.utils import with_metaclass

# TODO: Register the node instances as well?
#    This would allow instance initialization when an extension is activated.
#    Implementing this should not be too hard via the metaclass.

# TODO: Add warning about overriding public methods with respect to
#    the docstring wrappers?
# TODO: in the future could use ABC's to register nodes with extension nodes


# name prefix used for the original attributes when they are shadowed
ORIGINAL_ATTR_PREFIX = "_non_extension_"
# prefix used to store the current extension name for an attribute,
# the value stored in this attribute is the extension name
_EXTENSION_ATTR_PREFIX = "_extension_for_"
# list of attribute names that are not affected by extensions
_NON_EXTENSION_ATTRIBUTES = ["__module__", "__doc__", "extension_name"]

# keys under which the global activation and deactivation functions
# for extensions can be stored in the extension registry
_SETUP_FUNC_ATTR = "_extension_setup"
_TEARDOWN_FUNC_ATTR = "_extension_teardown"

# dict of dicts of dicts, contains a key for each extension,
# the inner dict maps the node types to their extension node,
# the innermost dict then maps attribute names to values
# (e.g. a method name to the actual function)
#
# For each extension there are also the special _SETUP_FUNC_ATTR and
# _TEARDOWN_FUNC_ATTR keys.
_extensions = dict()
# set containing the names of the currently activated extensions
_active_extensions = set()


class ExtensionException(MDPException):
    """Base class for extension related exceptions."""
    pass


def _register_attribute(ext_name, node_cls, attr_name, attr_value):
    """Register an attribute as an extension attribute.

    The entries ``_extensions[ext_name]`` and
    ``_extensions[ext_name][node_cls]`` must already exist, otherwise a
    ``KeyError`` is raised.

    ext_name -- String with the name of the extension.
    node_cls -- Node class for which the method should be registered.
    attr_name -- Name under which the attribute is stored for the node class.
    attr_value -- Value (e.g. a function) stored under that name.
    """
    _extensions[ext_name][node_cls][attr_name] = attr_value


def extension_method(extension_name, node_cls, method_name=None):
    """Returns a decorator to register a function as an extension method.

    :Parameters:
      extension_name
        String with the name of the extension.
      node_cls
        Node class for which the method should be registered.
      method_name
        Name of the extension method (default value is ``None``).

        If no value is provided then the name of the function is used.

    Note that it is possible to directly call other extension functions, call
    extension methods in other node classes or to use super in the normal way
    (the function will be called as a method of the node class).
    """
    def register_function(func):
        # fall back to the function name if no (or an empty) name was given
        _method_name = method_name
        if not _method_name:
            _method_name = func.__name__
        if extension_name not in _extensions:
            # creation of a new extension, add entry in dict
            _extensions[extension_name] = dict()
        if node_cls not in _extensions[extension_name]:
            # register this node
            _extensions[extension_name][node_cls] = dict()
        _register_attribute(extension_name, node_cls, _method_name, func)
        # the function itself is returned unchanged
        return func
    return register_function


def extension_setup(extension_name):
    """Returns a decorator to register a setup function for an extension.

    :Parameters:
      extension_name
        String with the name of the extension.

    The decorated function will be called when the extension is activated.
    An ``ExtensionException`` is raised by the decorator if a setup function
    has already been registered for this extension.

    Note that there is also the extension_teardown decorator, which should
    probably be defined as well if there is a setup procedure.
    """
    def register_setup_function(func):
        if extension_name not in _extensions:
            # creation of a new extension, add entry in dict
            _extensions[extension_name] = dict()
        if _SETUP_FUNC_ATTR in _extensions[extension_name]:
            err = "There is already a setup function for this extension."
            raise ExtensionException(err)
        _extensions[extension_name][_SETUP_FUNC_ATTR] = func
        return func
    return register_setup_function


def extension_teardown(extension_name):
    """Returns a decorator to register a teardown function for an extension.

    :Parameters:
      extension_name
        String with the name of the extension.

    The decorated function will be called when the extension is deactivated.
    An ``ExtensionException`` is raised by the decorator if a teardown
    function has already been registered for this extension.
    """
    def register_teardown_function(func):
        if extension_name not in _extensions:
            # creation of a new extension, add entry in dict
            _extensions[extension_name] = dict()
        if _TEARDOWN_FUNC_ATTR in _extensions[extension_name]:
            err = "There is already a teardown function for this extension."
            raise ExtensionException(err)
        _extensions[extension_name][_TEARDOWN_FUNC_ATTR] = func
        return func
    return register_teardown_function


class ExtensionNodeMetaclass(NodeMetaclass):
    """This is the metaclass for node extension superclasses.

    It takes care of registering extensions and the attributes in the
    extension.
    """

    def __new__(cls, classname, bases, members):
        """Create new node classes and register extensions.

        If a concrete extension node is created then a corresponding mixin
        class is automatically created and registered.

        An ``ExtensionException`` is raised if a class directly derived from
        ``ExtensionNode`` has no (or an empty) ``extension_name``, if the
        extension name is already registered, or if the new class is derived
        from more than one normal (non-extension) node class.
        """
        if classname == "ExtensionNode":
            # initial creation of ExtensionNode class
            return super(ExtensionNodeMetaclass, cls).__new__(cls, classname,
                                                              bases, members)
        # check if this is a new extension definition,
        # in that case this node is directly derived from ExtensionNode
        if ExtensionNode in bases:
            # use .get so that a class body without 'extension_name' gets the
            # ExtensionException below instead of a bare KeyError
            ext_name = members.get("extension_name")
            if not ext_name:
                err = "No extension name has been specified."
                raise ExtensionException(err)
            if ext_name not in _extensions:
                # creation of a new extension, add entry in dict
                _extensions[ext_name] = dict()
            else:
                err = ("An extension with the name '" + ext_name +
                       "' has already been registered.")
                raise ExtensionException(err)
        # find the node that this extension node belongs to
        base_node_cls = None
        for base in bases:
            if type(base) is not ExtensionNodeMetaclass:
                if base_node_cls is None:
                    base_node_cls = base
                else:
                    err = ("Extension node derived from multiple "
                           "normal nodes.")
                    raise ExtensionException(err)
        if base_node_cls is None:
            # This new extension is not directly derived from another class,
            # so there is nothing to register (no default implementation).
            # We disable the doc method extension mechanism as this class
            # is not a node subclass and adding methods (e.g. _execute) would
            # cause problems.
            cls.DOC_METHODS = []
            return super(ExtensionNodeMetaclass, cls).__new__(cls, classname,
                                                              bases, members)
        ext_node_cls = super(ExtensionNodeMetaclass, cls).__new__(
                                                cls, classname, bases, members)
        ext_name = ext_node_cls.extension_name
        if base_node_cls not in _extensions[ext_name]:
            # register the base node
            _extensions[ext_name][base_node_cls] = dict()
        # Register methods from extension class hierarchy: iterate MRO in
        # reverse order and register all attributes starting from the
        # classes which are subclasses from ExtensionNode.
        extension_subtree = False
        for base in reversed(ext_node_cls.__mro__):
            # make sure we only inject methods in classes which have
            # ExtensionNode as superclass
            if extension_subtree and ExtensionNode in base.__mro__:
                for attr_name, attr_value in list(base.__dict__.items()):
                    if attr_name in _NON_EXTENSION_ATTRIBUTES:
                        continue
                    # check if this attribute has not already been
                    # extended in one of the base classes
                    already_active = any(
                        bb in _extensions[ext_name] and
                        attr_name in _extensions[ext_name][bb] and
                        _extensions[ext_name][bb][attr_name] == attr_value
                        for bb in ext_node_cls.__mro__)
                    # only register if not yet active
                    if not already_active:
                        _register_attribute(ext_name, base_node_cls,
                                            attr_name, attr_value)
            # everything after ExtensionNode in the reversed MRO belongs to
            # the extension subtree
            if base == ExtensionNode:
                extension_subtree = True
        return ext_node_cls


class ExtensionNode(with_metaclass(ExtensionNodeMetaclass, object)):
    """Base class for extensions nodes.

    A new extension node class should override the `extension_name`
    attribute. The concrete node implementations are then derived from this
    extension node class.

    To call an instance method from a parent class you have multiple options:

    - use super, but with the normal node class, e.g.:

      >>> super(mdp.nodes.SFA2Node, self).method()      # doctest: +SKIP

      Here SFA2Node was given instead of the extension node class for the
      SFA2Node.

      If the extensions node class is used directly (without the extension
      mechanism) this can cause problems. In that case you have to be
      careful about the inheritance order and the effect on the MRO.

    - call it explicitly using the __func__ attribute [python version < 3]:

      >>> parent_class.method.__func__(self)            # doctest: +SKIP

      or [python version >=3]:

      >>> parent_class.method(self)                     # doctest: +SKIP

    To call the original (pre-extension) method in the same class you
    simply prefix the method name with '_non_extension_' (this is the value
    of the `ORIGINAL_ATTR_PREFIX` constant in this module).
    """
    # override this name in a concrete extension node base class
    extension_name = None


def get_extensions():
    """Return a dictionary with the currently registered extensions.

    Note that this is not a copy, so if you change anything in this dict
    the whole extension mechanism will be affected. If you just want the
    names of the available extensions use get_extensions().keys().
    """
    return _extensions


def get_active_extensions():
    """Returns a list with the names of the currently activated extensions."""
    # use copy to protect the original set, also important if the return
    # value is used in a for-loop (see deactivate_extensions function)
    return list(_active_extensions)


def activate_extension(extension_name, verbose=False):
    """Activate the extension by injecting the extension methods.

    extension_name -- String with the name of the extension.
    verbose -- If True then each injected or overwritten attribute is
        printed (default is False).

    An ``ExtensionException`` is raised if the extension name is unknown or
    if another active extension already overrides one of the attributes
    directly on the same node class. If anything goes wrong during the
    activation the extension is deactivated again before the exception is
    re-raised. Activating an already active extension does nothing.
    """
    if extension_name not in list(_extensions.keys()):
        err = "Unknown extension name: %s"%str(extension_name)
        raise ExtensionException(err)
    if extension_name in _active_extensions:
        if verbose:
            print('Extension %s is already active!' % extension_name)
        return
    # mark as active first, so that deactivate_extension can be used to
    # revert a partial activation
    _active_extensions.add(extension_name)
    try:
        if _SETUP_FUNC_ATTR in _extensions[extension_name]:
            _extensions[extension_name][_SETUP_FUNC_ATTR]()
        for node_cls, attributes in list(_extensions[extension_name].items()):
            # the setup / teardown functions live in the same dict as the
            # node classes, skip them here
            if node_cls == _SETUP_FUNC_ATTR or node_cls == _TEARDOWN_FUNC_ATTR:
                continue
            for attr_name, attr_value in list(attributes.items()):
                if verbose:
                    print("extension %s: adding %s to %s" %
                          (extension_name, attr_name, node_cls.__name__))
                ## store the original attribute / make it available
                ext_attr_name = _EXTENSION_ATTR_PREFIX + attr_name
                if attr_name in dir(node_cls):
                    if ext_attr_name in node_cls.__dict__:
                        # two extensions override the same attribute
                        err = ("Name collision for attribute '" +
                               attr_name + "' between extension '" +
                               getattr(node_cls, ext_attr_name)
                               + "' and newly activated extension '" +
                               extension_name + "'.")
                        raise ExtensionException(err)
                    # only overwrite the attribute if the extension is not
                    # yet active on this class or its superclasses
                    if ext_attr_name not in dir(node_cls):
                        original_attr = getattr(node_cls, attr_name)
                        if verbose:
                            print("extension %s: overwriting %s in %s" %
                                  (extension_name, attr_name,
                                   node_cls.__name__))
                        setattr(node_cls, ORIGINAL_ATTR_PREFIX + attr_name,
                                original_attr)
                setattr(node_cls, attr_name, attr_value)
                # store to which extension this attribute belongs, this is also
                # used as a flag that this is an extension attribute
                setattr(node_cls, ext_attr_name, extension_name)
    except Exception:
        # make sure that an incomplete activation is reverted
        deactivate_extension(extension_name)
        raise


def deactivate_extension(extension_name, verbose=False):
    """Deactivate the extension by removing the injected methods.

    extension_name -- String with the name of the extension.
    verbose -- If True then each removed or restored attribute is printed
        (default is False).

    An ``ExtensionException`` is raised if the extension name is unknown.
    Deactivating an extension that is not active does nothing. The registered
    teardown function (if any) is called after the attributes were removed.
    """
    if extension_name not in list(_extensions.keys()):
        err = "Unknown extension name: " + str(extension_name)
        raise ExtensionException(err)
    if extension_name not in _active_extensions:
        return
    for node_cls, attributes in list(_extensions[extension_name].items()):
        # the setup / teardown functions live in the same dict as the
        # node classes, skip them here
        if node_cls == _SETUP_FUNC_ATTR or node_cls == _TEARDOWN_FUNC_ATTR:
            continue
        for attr_name in list(attributes.keys()):
            original_name = ORIGINAL_ATTR_PREFIX + attr_name
            if verbose:
                print("extension %s: removing %s from %s" %
                      (extension_name, attr_name, node_cls.__name__))
            if original_name in node_cls.__dict__:
                # restore the original attribute
                if verbose:
                    print("extension %s: restoring %s in %s" %
                          (extension_name, attr_name, node_cls.__name__))
                delattr(node_cls, attr_name)
                original_attr = getattr(node_cls, original_name)
                # Check if the attribute is defined by one of the super
                # classes and test if the overwritten method is not that
                # method, otherwise we would inject unwanted methods.
                # Note: '==' tests identity for .__func__ and .__self__,
                #    but .im_class does not matter in Python 2.6.
                if all(getattr(x, attr_name, None) != original_attr
                       for x in node_cls.__mro__[1:]):
                    setattr(node_cls, attr_name, original_attr)
                delattr(node_cls, original_name)
            else:
                try:
                    # no original attribute to restore, so simply delete
                    # might be missing if the activation failed
                    delattr(node_cls, attr_name)
                except AttributeError:
                    pass
            try:
                # might be missing if the activation failed
                delattr(node_cls, _EXTENSION_ATTR_PREFIX + attr_name)
            except AttributeError:
                pass
    if _TEARDOWN_FUNC_ATTR in _extensions[extension_name]:
        _extensions[extension_name][_TEARDOWN_FUNC_ATTR]()
    _active_extensions.remove(extension_name)


def activate_extensions(extension_names, verbose=False):
    """Activate all the extensions for the given names.

    extension_names -- Sequence of extension names.
    verbose -- Passed on to activate_extension (default is False).

    If the activation of any extension fails then all currently active
    extensions are deactivated before the exception is re-raised.
    """
    try:
        for extension_name in extension_names:
            activate_extension(extension_name, verbose=verbose)
    except:
        # if something goes wrong deactivate all, otherwise we might be
        # in an inconsistent state (e.g. methods for active extensions might
        # have been removed)
        deactivate_extensions(get_active_extensions())
        raise


def deactivate_extensions(extension_names, verbose=False):
    """Deactivate all the extensions for the given names.

    extension_names -- Sequence of extension names.
    verbose -- Passed on to deactivate_extension (default is False).
    """
    for extension_name in extension_names:
        deactivate_extension(extension_name, verbose=verbose)


# TODO: add check that only extensions are deactivated that were
#    originally activated by this extension (same in context manager)
#    also add test for this
def with_extension(extension_name):
    """Return a wrapper function to activate and deactivate the extension.

    This function is intended to be used with the decorator syntax.

    The deactivation happens only if the extension was activated by
    the decorator (not if it was already active before). So this
    decorator ensures that the extension is active and prevents
    unintended side effects.

    If the generated function is a generator, the extension will be in
    effect only when the generator object is created (that is when the
    function is called, but its body is not actually immediately
    executed). When the function body is executed (after ``next`` is
    called on the generator object), the extension might not be in
    effect anymore. Therefore, it is better to use the `extension`
    context manager with a generator function.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            # make sure that we don't deactivate an extension that was
            # not activated by the decorator (would be a strange side effect)
            if extension_name not in get_active_extensions():
                try:
                    activate_extension(extension_name)
                    result = func(*args, **kwargs)
                finally:
                    deactivate_extension(extension_name)
            else:
                result = func(*args, **kwargs)
            return result
        # now make sure that docstring and signature match the original
        func_info = NodeMetaclass._function_infodict(func)
        return NodeMetaclass._wrap_function(wrapper, func_info)
    return decorator


class extension(object):
    """Context manager for MDP extension.

    This allows you to use extensions using a ``with`` statement, as in:

    >>> with mdp.extension('extension_name'):
    ...     # 'node' is executed with the extension activated
    ...     node.execute(x)

    It is also possible to activate multiple extensions at once:

    >>> with mdp.extension(['ext1', 'ext2']):
    ...     # 'node' is executed with the two extensions activated
    ...     node.execute(x)

    The deactivation at the end happens only for the extensions that were
    activated by this context manager (not for those that were already active
    when the context was entered). This prevents unintended side effects.
    """

    def __init__(self, ext_names):
        """Store the extension names that should be active in the context.

        ext_names -- A single extension name (string) or a sequence of
            extension names.
        """
        # Compare against the interpreter's native string types. The
        # module-level ``str`` is imported from ``builtins`` and
        # ``__builtins__`` may be either a dict or a module depending on
        # how the module is run, so neither is used for this check.
        native_string_types = (type(''), type(u''))
        if isinstance(ext_names, native_string_types):
            ext_names = [ext_names]
        self.ext_names = ext_names
        # filled in __enter__ with the extensions this context activated
        self.deactivate_exts = []

    def __enter__(self):
        """Activate the extensions, remembering which were newly activated."""
        already_active = get_active_extensions()
        self.deactivate_exts = [ext_name for ext_name in self.ext_names
                                if ext_name not in already_active]
        activate_extensions(self.ext_names)

    def __exit__(self, type, value, traceback):
        """Deactivate only the extensions that __enter__ activated."""
        deactivate_extensions(self.deactivate_exts)