1"""
2Extension Mechanism for nodes.
3
4The extension mechanism makes it possible to dynamically add class attributes,
5especially methods, for specific features to node classes
6(e.g. nodes need a _fork and _join method for parallelization).
7It is also possible for users to define new extensions to provide new
8functionality for MDP nodes without having to modify any MDP code.
9
Without the extension mechanism, extending nodes would be done by inheritance,
11which is fine unless one wants to use multiple inheritance at the same time
12(requiring multiple inheritance for every combination of extensions one wants
13to use). The extension mechanism does not depend on inheritance, instead it
14adds the methods to the node classes dynamically at runtime. This makes it
15possible to activate extensions just when they are needed, reducing the risk
16of interference between different extensions.
17
However, since the extension mechanism provides a special metaclass, it is
19still possible to define the extension nodes as classes derived from nodes.
20This keeps the code readable and is compatible with automatic code checkers
21(like the background pylint checks in the Eclipse IDE with PyDev).
22"""
23from __future__ import print_function
24from builtins import str
25from builtins import object
26
27from mdp import MDPException, NodeMetaclass
28from future.utils import with_metaclass
29
30# TODO: Register the node instances as well?
31#    This would allow instance initialization when an extension is activated.
#    Implementing this should not be too hard via the metaclass.
33
34# TODO: Add warning about overriding public methods with respect to
35#    the docstring wrappers?
36
37# TODO: in the future could use ABC's to register nodes with extension nodes
38
39
# Prefix of the class attribute under which the original (pre-extension)
# value of a shadowed attribute is kept while an extension is active.
ORIGINAL_ATTR_PREFIX = "_non_extension_"
# Prefix of the marker attribute written next to every injected attribute;
# the marker holds the name of the extension that injected it.
_EXTENSION_ATTR_PREFIX = "_extension_for_"
# Class-body names that are never treated as extension attributes.
_NON_EXTENSION_ATTRIBUTES = ["__module__", "__doc__", "extension_name"]

# Reserved keys inside an extension's registry entry that hold the optional
# global setup (activation) and teardown (deactivation) functions.
_SETUP_FUNC_ATTR = "_extension_setup"
_TEARDOWN_FUNC_ATTR = "_extension_teardown"

# The extension registry, nested three levels deep:
#     extension name -> node class -> attribute name -> attribute value
# (for example a method name mapped to the function implementing it).
# Next to the node-class keys, an extension entry may also contain the
# reserved _SETUP_FUNC_ATTR and _TEARDOWN_FUNC_ATTR keys.
_extensions = {}
# Names of the extensions that are active right now.
_active_extensions = set()
63
64
class ExtensionException(MDPException):
    """Common base exception for errors raised by the extension mechanism."""
68
69
def _register_attribute(ext_name, node_cls, attr_name, attr_value):
    """Record one attribute in the registry of an extension.

    ext_name -- String with the name of the extension.
    node_cls -- Node class the attribute is registered for.
    attr_name -- Name under which the attribute will be injected.
    attr_value -- The value to inject (e.g. the method's function).

    The registry entries for ext_name and node_cls must already exist,
    otherwise a KeyError is raised.
    """
    node_attributes = _extensions[ext_name][node_cls]
    node_attributes[attr_name] = attr_value
77
78
def extension_method(extension_name, node_cls, method_name=None):
    """Returns a decorator to register a function as an extension method.

    :Parameters:
      extension_name
        String with the name of the extension.
      node_cls
        Node class for which the method should be registered.
      method_name
        Name of the extension method (default value is ``None``).

        If no (or an empty) name is provided then the name of the
        decorated function is used.

    The decorated function is returned unchanged. Inside it you can call
    other extension functions, call extension methods of other node classes
    or use super in the normal way (the function will be called as a method
    of the node class).
    """
    def register_function(func):
        # an empty or missing name falls back to the function's own name
        name = method_name or func.__name__
        # create the registry entries for the extension / node on demand
        ext_entry = _extensions.setdefault(extension_name, {})
        ext_entry.setdefault(node_cls, {})
        _register_attribute(extension_name, node_cls, name, func)
        return func
    return register_function
109
110
def extension_setup(extension_name):
    """Returns a decorator to register a setup function for an extension.

    :Parameters:
      extension_name
        String with the name of the extension.

    The decorated function is called (without arguments) when the extension
    gets activated, and is returned unchanged by the decorator. Only one
    setup function per extension is allowed; registering a second one
    raises ExtensionException.

    If there is a setup procedure you probably want to register a matching
    function with the extension_teardown decorator as well.
    """
    def register_setup_function(func):
        ext_entry = _extensions.setdefault(extension_name, {})
        if _SETUP_FUNC_ATTR in ext_entry:
            err = "There is already a setup function for this extension."
            raise ExtensionException(err)
        ext_entry[_SETUP_FUNC_ATTR] = func
        return func
    return register_setup_function
133
134
def extension_teardown(extension_name):
    """Returns a decorator to register a teardown function for an extension.

    :Parameters:
      extension_name
        String with the name of the extension.

    The decorated function is called (without arguments) when the extension
    gets deactivated, and is returned unchanged by the decorator. Only one
    teardown function per extension is allowed; registering a second one
    raises ExtensionException.
    """
    def register_teardown_function(func):
        ext_entry = _extensions.setdefault(extension_name, {})
        if _TEARDOWN_FUNC_ATTR in ext_entry:
            err = "There is already a teardown function for this extension."
            raise ExtensionException(err)
        ext_entry[_TEARDOWN_FUNC_ATTR] = func
        return func
    return register_teardown_function
154
155
class ExtensionNodeMetaclass(NodeMetaclass):
    """This is the metaclass for node extension superclasses.

    It takes care of registering extensions and the attributes in the
    extension. Three kinds of classes pass through it:

    - ``ExtensionNode`` itself, which is created without registering
      anything;
    - extension classes derived directly from ``ExtensionNode``, which must
      set a non-empty ``extension_name`` and create a new (empty) entry in
      the extension registry;
    - concrete extension nodes, derived from an extension class and at most
      one normal node class, whose attributes are registered for that
      normal node class under the extension name.
    """

    def __new__(cls, classname, bases, members):
        """Create new node classes and register extensions.

        If a concrete extension node is created then a corresponding mixin
        class is automatically created and registered.

        Raises ExtensionException if a class derived directly from
        ExtensionNode does not set a (truthy) extension_name, if that
        extension name is already registered, or if more than one base
        class is not an extension node class.
        """
        if classname == "ExtensionNode":
            # initial creation of ExtensionNode class
            return super(ExtensionNodeMetaclass, cls).__new__(cls, classname,
                                                              bases, members)
        # check if this is a new extension definition,
        # in that case this node is directly derived from ExtensionNode
        if ExtensionNode in bases:
            # .get() so that a missing extension_name is reported with the
            # intended ExtensionException instead of a bare KeyError
            ext_name = members.get("extension_name")
            if not ext_name:
                err = "No extension name has been specified."
                raise ExtensionException(err)
            if ext_name not in _extensions:
                # creation of a new extension, add entry in dict
                _extensions[ext_name] = dict()
            else:
                err = ("An extension with the name '" + ext_name +
                       "' has already been registered.")
                raise ExtensionException(err)
        # find the node that this extension node belongs to: the single base
        # class whose metaclass is not exactly ExtensionNodeMetaclass
        base_node_cls = None
        for base in bases:
            if type(base) is not ExtensionNodeMetaclass:
                if base_node_cls is None:
                    base_node_cls = base
                else:
                    err = ("Extension node derived from multiple "
                           "normal nodes.")
                    raise ExtensionException(err)
        if base_node_cls is None:
            # This new extension is not directly derived from another class,
            # so there is nothing to register (no default implementation).
            # We disable the doc method extension mechanism as this class
            # is not a node subclass and adding methods (e.g. _execute) would
            # cause problems.
            # NOTE: this assigns to the metaclass (cls), not to the new
            # class, so the setting stays in effect for every class created
            # with this metaclass afterwards.
            cls.DOC_METHODS = []
            return super(ExtensionNodeMetaclass, cls).__new__(cls, classname,
                                                              bases, members)
        ext_node_cls = super(ExtensionNodeMetaclass, cls).__new__(
                                                cls, classname, bases, members)
        # resolved through the MRO, so it may come from the extension class
        ext_name = ext_node_cls.extension_name
        if base_node_cls not in _extensions[ext_name]:
            # register the base node
            _extensions[ext_name][base_node_cls] = dict()
        # Register methods from extension class hierarchy: iterate MRO in
        # reverse order and register all attributes starting from the
        # classes which are subclasses from ExtensionNode.
        extension_subtree = False
        for base in reversed(ext_node_cls.__mro__):
            # make sure we only inject methods in classes which have
            # ExtensionNode as superclass
            if extension_subtree and ExtensionNode in base.__mro__:
                for attr_name, attr_value in list(base.__dict__.items()):
                    if attr_name not in _NON_EXTENSION_ATTRIBUTES:
                        # check if this attribute has not already been
                        # extended in one of the base classes (same value
                        # registered for a class in the MRO)
                        already_active = False
                        for bb in ext_node_cls.__mro__:
                            if (bb in _extensions[ext_name] and
                            attr_name in _extensions[ext_name][bb] and
                            _extensions[ext_name][bb][attr_name] == attr_value):
                                already_active = True
                        # only register if not yet active
                        if not already_active:
                            _register_attribute(ext_name, base_node_cls,
                                                attr_name, attr_value)
            if base == ExtensionNode:
                extension_subtree = True
        return ext_node_cls
237
238
class ExtensionNode(with_metaclass(ExtensionNodeMetaclass, object)):
    """Base class for extension nodes.

    A new extension is defined by a class derived directly from
    ExtensionNode, which must set the ``extension_name`` class attribute
    (otherwise ExtensionException is raised when the class is created).
    The concrete node implementations are then derived from this extension
    node class.

    To call an instance method from a parent class you have multiple options:

    - use super, but with the normal node class, e.g.:

      >>>  super(mdp.nodes.SFA2Node, self).method()      # doctest: +SKIP

      Here SFA2Node was given instead of the extension node class for the
      SFA2Node.

      If the extension node class is used directly (without the extension
      mechanism) this can cause problems. In that case you have to be
      careful about the inheritance order and the effect on the MRO.

    - call it explicitly using the __func__ attribute [python version < 3]:

      >>> parent_class.method.__func__(self)             # doctest: +SKIP

      or [python version >=3]:

      >>> parent_class.method(self)                      # doctest: +SKIP

    To call the original (pre-extension) method in the same class you
    simply prefix the method name with '_non_extension_' (this is the value
    of the `ORIGINAL_ATTR_PREFIX` constant in this module).
    """
    # override this name in a concrete extension node base class
    extension_name = None
273
274
def get_extensions():
    """Return the registry dict of all registered extensions.

    The registry maps every extension name to its per-node-class attribute
    dicts. The live registry is returned, not a copy, so any change made to
    it affects the whole extension mechanism. If only the names of the
    available extensions are needed, use get_extensions().keys().
    """
    registry = _extensions
    return registry
283
def get_active_extensions():
    """Returns a list with the names of the currently activated extensions."""
    # Hand out a fresh list instead of the internal set: callers cannot
    # mutate the set by accident, and the result can safely be iterated
    # while extensions are being deactivated (see deactivate_extensions).
    names = []
    names.extend(_active_extensions)
    return names
289
def activate_extension(extension_name, verbose=False):
    """Activate the extension by injecting the extension methods.

    extension_name -- Name of a registered extension.
    verbose -- If True, print a message for every added or overwritten
        attribute.

    The setup function of the extension (if registered) is called first.
    Then, for every registered node class, each extension attribute is set
    on the class together with a marker attribute
    ``_EXTENSION_ATTR_PREFIX + name`` that holds the extension name. If the
    class already provided an attribute of that name, the previous value is
    saved under ``ORIGINAL_ATTR_PREFIX + name`` so deactivate_extension can
    restore it.

    Activating an already active extension does nothing. Raises
    ExtensionException for an unknown extension name, or when an attribute
    is already injected into the same class by another extension. On any
    error, deactivate_extension is called to revert the partial activation
    and the error is re-raised.
    """
    if extension_name not in list(_extensions.keys()):
        err = "Unknown extension name: %s" % str(extension_name)
        raise ExtensionException(err)
    if extension_name in _active_extensions:
        if verbose:
            print('Extension %s is already active!' % extension_name)
        return
    # mark as active up front, so that the rollback via deactivate_extension
    # in the except clause actually does its work
    _active_extensions.add(extension_name)
    try:
        if _SETUP_FUNC_ATTR in _extensions[extension_name]:
            _extensions[extension_name][_SETUP_FUNC_ATTR]()
        for node_cls, attributes in list(_extensions[extension_name].items()):
            # skip the special setup / teardown entries
            if node_cls == _SETUP_FUNC_ATTR or node_cls == _TEARDOWN_FUNC_ATTR:
                continue
            for attr_name, attr_value in list(attributes.items()):
                if verbose:
                    print("extension %s: adding %s to %s" %
                          (extension_name, attr_name, node_cls.__name__))
                ## store the original attribute / make it available
                ext_attr_name = _EXTENSION_ATTR_PREFIX + attr_name
                if attr_name in dir(node_cls):
                    if ext_attr_name in node_cls.__dict__:
                        # two extensions override the same attribute
                        err = ("Name collision for attribute '" +
                               attr_name + "' between extension '" +
                               getattr(node_cls, ext_attr_name)
                               + "' and newly activated extension '" +
                               extension_name + "'.")
                        raise ExtensionException(err)
                    # Save the current value so it can be restored later.
                    # An inherited value is skipped if a superclass already
                    # carries an extension marker (it may be an injected
                    # extension attribute). A value in the class's own
                    # __dict__ is always saved: it has no marker here
                    # (checked above), so it is the class's own definition
                    # and must not be lost on deactivation, even if a
                    # superclass happened to be extended first.
                    if (attr_name in node_cls.__dict__ or
                            ext_attr_name not in dir(node_cls)):
                        original_attr = getattr(node_cls, attr_name)
                        if verbose:
                            print("extension %s: overwriting %s in %s" %
                                  (extension_name, attr_name,
                                   node_cls.__name__))
                        setattr(node_cls, ORIGINAL_ATTR_PREFIX + attr_name,
                                original_attr)
                setattr(node_cls, attr_name, attr_value)
                # store to which extension this attribute belongs, this is also
                # used as a flag that this is an extension attribute
                setattr(node_cls, ext_attr_name, extension_name)
    except Exception:
        # make sure that an incomplete activation is reverted
        deactivate_extension(extension_name)
        raise
338
def deactivate_extension(extension_name, verbose=False):
    """Deactivate the extension by removing the injected methods.

    extension_name -- Name of a registered extension.
    verbose -- If True, print a message for every removed or restored
        attribute.

    Each injected attribute is removed again. Where an original value was
    saved under ``ORIGINAL_ATTR_PREFIX + name``, that value is restored.
    Afterwards the teardown function of the extension (if registered) is
    called. Deactivating an inactive extension does nothing; an unknown
    extension name raises ExtensionException.

    Only attributes whose marker ``_EXTENSION_ATTR_PREFIX + name`` in the
    class's own ``__dict__`` holds this extension's name are touched. This
    matters when rolling back a failed activation. Attributes that were
    never injected (the class still has its own version) are left alone.
    So are attributes owned by another, colliding extension.
    """
    if extension_name not in list(_extensions.keys()):
        err = "Unknown extension name: " + str(extension_name)
        raise ExtensionException(err)
    if extension_name not in _active_extensions:
        return
    for node_cls, attributes in list(_extensions[extension_name].items()):
        # skip the special setup / teardown entries
        if node_cls == _SETUP_FUNC_ATTR or node_cls == _TEARDOWN_FUNC_ATTR:
            continue
        for attr_name in list(attributes.keys()):
            ext_attr_name = _EXTENSION_ATTR_PREFIX + attr_name
            if node_cls.__dict__.get(ext_attr_name) != extension_name:
                # Not injected by this extension on this class (activation
                # failed before reaching it, or another extension owns it):
                # deleting anything here would destroy foreign attributes.
                continue
            original_name = ORIGINAL_ATTR_PREFIX + attr_name
            if verbose:
                print("extension %s: removing %s from %s" %
                      (extension_name, attr_name, node_cls.__name__))
            if original_name in node_cls.__dict__:
                # restore the original attribute
                if verbose:
                    print("extension %s: restoring %s in %s" %
                          (extension_name, attr_name, node_cls.__name__))
                delattr(node_cls, attr_name)
                original_attr = getattr(node_cls, original_name)
                # Check if the attribute is defined by one of the super
                # classes and test if the overwritten method is not that
                # method, otherwise we would inject unwanted methods.
                # Note: '==' tests identity for .__func__ and .__self__,
                #    but .im_class does not matter in Python 2.6.
                if all([getattr(x, attr_name, None) !=
                           original_attr for x in node_cls.__mro__[1:]]):
                    setattr(node_cls, attr_name, original_attr)
                delattr(node_cls, original_name)
            else:
                try:
                    # no original attribute to restore, so simply delete
                    delattr(node_cls, attr_name)
                except AttributeError:
                    # tolerate an attribute that was already removed
                    pass
            # the marker is known to be in node_cls.__dict__ (checked above)
            delattr(node_cls, ext_attr_name)
    if _TEARDOWN_FUNC_ATTR in _extensions[extension_name]:
        _extensions[extension_name][_TEARDOWN_FUNC_ATTR]()
    _active_extensions.remove(extension_name)
385
def activate_extensions(extension_names, verbose=False):
    """Activate the extensions with the given names, in order.

    extension_names -- Sequence of extension names.

    If any activation fails, *all* currently active extensions (not only
    those activated by this call) are deactivated before the error is
    re-raised, so no inconsistent state is left behind.
    """
    try:
        for name in extension_names:
            activate_extension(name, verbose=verbose)
    except BaseException:
        # A failed activation may have removed methods that other active
        # extensions rely on, so roll back everything that is active.
        deactivate_extensions(get_active_extensions())
        raise
400
def deactivate_extensions(extension_names, verbose=False):
    """Deactivate the extensions with the given names, in order.

    extension_names -- Sequence of extension names.
    verbose -- Passed on to deactivate_extension for every name.
    """
    for name in extension_names:
        deactivate_extension(name, verbose=verbose)
408
# TODO: add check that only extensions are deactivated that were
#    originally activated by this decorator (same in context manager)
#    also add test for this
def with_extension(extension_name):
    """Return a wrapper function to activate and deactivate the extension.

    This function is intended to be used with the decorator syntax.

    Deactivation happens only if the decorator itself activated the
    extension, not if it was already active beforehand. So the decorator
    guarantees that the extension is active during the call without
    causing unintended side effects.

    If the decorated function is a generator function, the extension is in
    effect only while the generator object is created (that is, when the
    function is called, but its body is not yet executed). When the body
    runs (after ``next`` is called on the generator object), the extension
    might no longer be in effect. Therefore, it is better to use the
    `extension` context manager with a generator function.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            if extension_name in get_active_extensions():
                # Already active: just call through. Deactivating an
                # extension we did not activate would be a strange side
                # effect.
                return func(*args, **kwargs)
            try:
                activate_extension(extension_name)
                return func(*args, **kwargs)
            finally:
                deactivate_extension(extension_name)
        # now make sure that docstring and signature match the original
        func_info = NodeMetaclass._function_infodict(func)
        return NodeMetaclass._wrap_function(wrapper, func_info)
    return decorator
447
class extension(object):
    """Context manager for MDP extension.

    This allows you to use extensions using a ``with`` statement, as in:

    >>> with mdp.extension('extension_name'):
    ...     # 'node' is executed with the extension activated
    ...     node.execute(x)

    It is also possible to activate multiple extensions at once:

    >>> with mdp.extension(['ext1', 'ext2']):
    ...     # 'node' is executed with the two extensions activated
    ...     node.execute(x)

    The deactivation at the end happens only for the extensions that were
    activated by this context manager (not for those that were already active
    when the context was entered). This prevents unintended side effects.
    """

    def __init__(self, ext_names):
        """Store the extension(s) to activate.

        ext_names -- A single extension name (string) or an iterable of
            extension names.
        """
        # A single name is wrapped in a list. The check uses the native
        # string types (type("") / type(u"")). It cannot use the ``str``
        # imported from ``builtins``, which on Python 2 is the ``future``
        # backport type, not the native str. It also cannot use
        # __builtins__['str'], which fails where __builtins__ is a module
        # (e.g. in __main__).
        if isinstance(ext_names, (type(""), type(u""))):
            ext_names = [ext_names]
        # materialize, so a one-shot iterable (e.g. a generator) is not
        # exhausted by __enter__ before the extensions are activated
        self.ext_names = list(ext_names)
        self.deactivate_exts = []

    def __enter__(self):
        """Activate the extensions, remembering which were not yet active."""
        already_active = get_active_extensions()
        self.deactivate_exts = [ext_name for ext_name in self.ext_names
                                if ext_name not in already_active]
        activate_extensions(self.ext_names)
        return self

    def __exit__(self, type, value, traceback):
        """Deactivate only the extensions activated by __enter__.

        Returns None, so exceptions raised in the ``with`` body propagate.
        """
        deactivate_extensions(self.deactivate_exts)