# Docstrings in this package are written in reStructuredText.
__docformat__ = 'restructuredtext en'
2
3from routines import (timediff, refcast, scast, rotate, random_rot,
4                      permute, symrand, norm2, cov2,
5                      mult_diag, comb, sqrtm, get_dtypes, nongeneral_svd,
6                      hermitian, cov_maxima,
7                      lrep, rrep, irep, orthogonal_permutations,
8                      izip_stretched,
9                      weighted_choice, bool_to_sign, sign_to_bool, gabor,
10                      invert_exp_funcs2)
11try:
12    from collections import OrderedDict
13except ImportError:
14    ## Getting an Ordered Dict for Python < 2.7
15    from _ordered_dict import OrderedDict
16
17try:
18    from tempfile import TemporaryDirectory
19except ImportError:
20    from temporarydir import TemporaryDirectory
21
22from introspection import dig_node, get_node_size, get_node_size_str
23from quad_forms import QuadraticForm, QuadraticFormException
24from covariance import (CovarianceMatrix, DelayCovarianceMatrix,
25                        MultipleCovarianceMatrices,CrossCovarianceMatrix)
26from progress_bar import progressinfo
27from slideshow import (basic_css, slideshow_css, HTMLSlideShow,
28                       image_slideshow_css, ImageHTMLSlideShow,
29                       SectionHTMLSlideShow, SectionImageHTMLSlideShow,
30                       image_slideshow, show_image_slideshow)
31
32from _symeig import SymeigException
33
34import mdp as _mdp
# Matrix multiplication. ``mult`` is kept as an alias of the numx ``dot``
# so that it could later be replaced by a wrapper around the BLAS 'gemm'
# routine without touching callers; ``matmult`` starts out as the very
# same function object.
mult = matmult = _mdp.numx.dot
40
if _mdp.numx_description == 'scipy':
    def matmult(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0):
        """Return alpha*(a*b) + beta*c.

        a,b,c : matrices
        alpha, beta: scalars
        trans_a : 0 (a not transposed), 1 (a transposed),
                  2 (a conjugate transposed)
        trans_b : 0 (b not transposed), 1 (b transposed),
                  2 (b conjugate transposed)

        The product is computed by the BLAS ``gemm`` routine returned by
        ``get_blas_funcs`` for the operands' dtypes; ``c`` takes part in
        that dtype selection whenever it is given.
        """
        # ``c`` is an array, so test it against None explicitly: ``if c:``
        # raised "truth value of an array is ambiguous" for any
        # multi-element matrix, and for a 1x1 zero matrix it silently left
        # ``c`` out of the BLAS dtype selection.
        if c is not None:
            operands = (a, b, c)
        else:
            operands = (a, b)
        gemm, = _mdp.numx_linalg.get_blas_funcs(('gemm',), operands)
        return gemm(alpha, a, b, beta, c, trans_a, trans_b)
57
# workaround to numpy issues with dtype behavior:
# 'f' (single precision) input is upcasted at least in the following
# functions, so each wrapper passes the result through ``refcast`` with the
# dtype of its (first) argument. They are real functions rather than
# lambdas so that they carry a name and a docstring.
_inv = _mdp.numx_linalg.inv
def inv(x):
    """Return ``numx_linalg.inv(x)`` refcast to ``x.dtype``."""
    return refcast(_inv(x), x.dtype)

_pinv = _mdp.numx_linalg.pinv
def pinv(x):
    """Return ``numx_linalg.pinv(x)`` refcast to ``x.dtype``."""
    return refcast(_pinv(x), x.dtype)

_solve = _mdp.numx_linalg.solve
def solve(x, y):
    """Return ``numx_linalg.solve(x, y)`` refcast to ``x.dtype``.

    Note that the dtype of ``x`` (not ``y``) determines the result dtype.
    """
    return refcast(_solve(x, y), x.dtype)
66
def svd(x, compute_uv=True):
    """Singular value decomposition that keeps the input dtype.

    Thin wrapper around ``numx_linalg.svd``: every returned array is
    passed through ``refcast`` with the dtype of ``x``, and a
    ``LinAlgError`` raised by the backend is re-raised as a
    ``SymeigException`` carrying the original message.

    x : array to decompose
    compute_uv : if true (the default) return the triple ``(u, s, v)``
                 produced by the backend; otherwise return only ``s``
    """
    dtype = x.dtype
    linalg = _mdp.numx_linalg
    try:
        if not compute_uv:
            return refcast(linalg.svd(x, compute_uv=False), dtype)
        u, s, v = linalg.svd(x)
        return tuple([refcast(arr, dtype) for arr in (u, s, v)])
    except linalg.LinAlgError as exc:
        raise SymeigException(str(exc))
80
# Public API of this package. The same list is handed to fixup_namespace()
# below, so exported objects coming from the private implementation
# modules get their ``__module__`` pointed at this package.
__all__ = [
    # covariance estimators and quadratic forms
    'CovarianceMatrix', 'DelayCovarianceMatrix', 'CrossCovarianceMatrix',
    'MultipleCovarianceMatrices',
    'QuadraticForm', 'QuadraticFormException',
    # routines, introspection, progress bar and the linalg wrappers
    # defined in this module
    'comb', 'cov2', 'dig_node', 'get_dtypes', 'get_node_size', 'hermitian',
    'inv', 'mult', 'mult_diag', 'nongeneral_svd', 'norm2', 'permute',
    'pinv', 'progressinfo', 'random_rot', 'refcast', 'rotate', 'scast',
    'solve', 'sqrtm', 'svd', 'symrand', 'timediff', 'matmult',
    # HTML slideshow support
    'HTMLSlideShow', 'ImageHTMLSlideShow',
    'basic_css', 'slideshow_css', 'image_slideshow_css',
    'SectionHTMLSlideShow', 'SectionImageHTMLSlideShow', 'image_slideshow',
    # assorted helpers
    'lrep', 'rrep', 'irep', 'orthogonal_permutations', 'izip_stretched',
    'weighted_choice', 'bool_to_sign', 'sign_to_bool',
    # classes imported with fallbacks, plus gabor and the namespace fixer
    'OrderedDict', 'TemporaryDirectory', 'gabor', 'fixup_namespace',
]
97
def _without_prefix(name, prefix):
    """Return ``name`` minus its leading ``prefix``, or None if absent."""
    return name[len(prefix):] if name.startswith(prefix) else None
103
104import os
# Debug switch for the namespace fixup below: set MDPNSDEBUG to a non-empty
# value to print every ``__module__`` rewrite and module deletion.
FIXUP_DEBUG = os.environ.get('MDPNSDEBUG')
106
def fixup_namespace(mname, names, old_modules, keep_modules=()):
    """Re-home exported objects onto ``mname`` and hide private submodules.

    Implementation code is split across several "private" submodules to
    keep it manageable, and their names are removed from the package
    namespace. Objects imported from those submodules still carry the
    submodule in their ``__module__`` attribute, which then points to
    something that cannot be imported: ``repr`` output and the
    documentation show the wrong import location, and generators such as
    epydoc and sphinx get confused. This function rewrites those
    attributes so that they name the exporting module instead.

    For every name in ``names`` the object ``<mname>.<name>`` is
    inspected; when its ``__module__`` equals ``"<mname>.<m>"`` for some
    ``m`` in ``old_modules``, it is set to ``"<mname>"``. Any object with
    a ``__module__`` attribute is handled, not only classes, and the same
    treatment is applied recursively to its public and dunder attributes
    (methods, inner classes, ...), because users see and read the
    documentation of those as well.

    If ``names`` is ``None``, every public name of ``<mname>`` (not
    starting with ``'_'``) is processed. A name that does not start with
    ``'__'`` and cannot be looked up raises ``AttributeError``.

    Afterwards every entry of ``old_modules`` not listed in
    ``keep_modules`` is deleted from the namespace of ``<mname>``;
    entries that are already missing (e.g. on reload) are skipped.

    Setting the environment variable ``MDPNSDEBUG`` to a non-empty value
    prints a trace of the changes.
    """
    import sys
    module = sys.modules[mname]
    if names is None:
        names = [attr for attr in dir(module) if not attr.startswith('_')]
    if FIXUP_DEBUG:
        print('NAMESPACE FIXUP: %s (%s)' % (module, mname))
    for attr in names:
        _fixup_namespace_item(module, mname, attr, old_modules, '')

    # Drop the implementation modules themselves, sparing the ones the
    # caller asked to keep visible.
    for filename in [m for m in old_modules if m not in keep_modules]:
        try:
            delattr(module, filename)
        except AttributeError:
            # Already gone: we are in a reload, nothing to do.
            continue
        if FIXUP_DEBUG:
            print('NAMESPACE FIXUP: deleting %s from %s' % (filename, module))
170
def _fixup_namespace_item(parent, mname, name, old_modules, path):
    """Recursively re-home ``parent.<name>`` onto module ``mname``.

    If the object's ``__module__`` is ``"<mname>.<m>"`` with ``m`` in
    ``old_modules``, its ``__module__`` is set to ``mname`` and the
    function recurses into the object's public and dunder attributes.
    Objects whose ``__module__`` cannot be assigned directly (method
    objects) get the wrapped function updated instead and are not
    recursed into.

    parent : object (module, class, ...) whose attribute is processed
    mname : name of the exporting module
    name : attribute name looked up on ``parent``
    old_modules : short names of the "private" implementation modules
    path : dotted attribute path leading to ``parent``, used only in the
           debug output

    Raises AttributeError if ``name`` cannot be read from ``parent`` and
    does not start with ``'__'``.
    """
    try:
        item = getattr(parent, name)
    except AttributeError:
        # Some dunder names listed by dir() cannot actually be read;
        # tolerate those, but let any other missing name propagate.
        if name.startswith('__'):
            return
        raise
    current_name = getattr(item, '__module__', None)
    if current_name is None:
        return
    # Only objects defined in one of the private submodules are touched.
    # Once re-homed, ``__module__`` equals ``mname`` and no longer matches
    # here, which also ends the recursion on reference cycles.
    if _without_prefix(current_name, mname + '.') not in old_modules:
        return
    if FIXUP_DEBUG:
        print('namespace fixup: {%s => %s}%s.%s' % (
            current_name, mname, path, name))
    try:
        item.__module__ = mname
    except AttributeError:
        # Method objects refuse attribute assignment; update the wrapped
        # function instead. ``__func__`` exists on Python 2.6+ and 3,
        # whereas the previously used ``im_func`` is Python 2 only, so
        # e.g. classmethods reached through their class were left
        # untouched on Python 3.
        try:
            item.__func__.__module__ = mname
        except AttributeError as e:
            if FIXUP_DEBUG:
                print('namespace fixup failed:  %s' % (e,))
        # Method objects are not recursed into (plain functions are,
        # since their ``__module__`` assignment above succeeds).
        return
    subitems = [_name for _name in dir(item)
                if _name.startswith('__') or not _name.startswith('_')]
    for subitem in subitems:
        _fixup_namespace_item(item, mname, subitem, old_modules,
                              path + '.' + name)
200
# Point the exported objects at this package and remove the private
# implementation modules (and the helper ``os`` import) from its namespace.
fixup_namespace(__name__, __all__,
                old_modules=('routines', 'introspection', 'quad_forms',
                             'covariance', 'progress_bar', 'slideshow',
                             '_ordered_dict', 'templet', 'temporarydir',
                             'os'))
213