1"""
2Implementation of the univariate matrix polynomial.
3The algebraic class is
4
5M[t]/<t^D>
6
7where M is the ring of matrices and t an external parameter
8
9"""
10
11import math
12
13import numpy.linalg
14import numpy
15import scipy.linalg
16
17from ..base_type import Ring
18from .._npversion import NumpyVersion
19
20from .algorithms import RawAlgorithmsMixIn, broadcast_arrays_shape
21
22import operator
23
24from algopy import nthderiv
25import algopy.utils
26
27
if NumpyVersion(numpy.version.version) >= '1.6.0':

    def workaround_strides_function(x, y, fun):
        """
        Perform the in-place operation fun(x, y) element by element.

        fun is typically an in-place operator such as operator.iadd or
        operator.imul; it is always called with exactly two arguments.

        This is a workaround for the numpy bug
        https://github.com/numpy/numpy/issues/2705
        and should be replaced once that bug has been fixed.

        The leading axis is traversed recursively until 0-dimensional
        elements are reached; fun is then applied to each pair of elements.

        Parameters
        ------------

        x:      UTPM instance

        y:      UTPM instance with the same shape as x

        fun:    function from the module operator

        Raises
        ------
        ValueError
            if x.shape != y.shape
        """

        if x.shape != y.shape:
            raise ValueError('x.shape != y.shape')

        if x.ndim != 0:
            # peel off the leading axis and recurse on every sub-block
            for idx in range(x.shape[0]):
                workaround_strides_function(x[idx, ...], y[idx, ...], fun)
            return

        fun(x, y)

else:

    def workaround_strides_function(x, y, fun):
        """Apply fun(x, y) directly; the workaround is not used here."""
        fun(x, y)
69
70
71class UTPM(Ring, RawAlgorithmsMixIn):
72    """
73
74    UTPM == Univariate Taylor Polynomial of Matrices
75    This class implements univariate Taylor arithmetic on matrices, i.e.
    [A]_D = \\sum_{d=0}^{D-1} A_d T^d
77
78    Input:
79    in the most general form, the input is a 4-tensor.
80    We use the notation:
81    D: degree of the Taylor series
82    P: number of directions
83    N: number of rows of A_0
84    M: number of cols of A_0
85
86    shape([A]) = (D,P,N,M)
87    The reason for this choice is that the (N,M) matrix is the elementary type,
88    so that memory should be contiguous.
89    Then, at each operation, the code performed to compute
90    v_d has to be repeated for every direction.
91
    E.g. an addition
    [w] = [u]+[v] =
    [[u_11, ..., u_1Ndir],
    ...
    [u_D1, ..., u_DNdir]]  +
    [[v_11, ..., v_1Ndir],
    ...
    [v_D1, ..., v_DNdir]] =
    [[ u_11 + v_11, ..., u_1Ndir + v_1Ndir],
    ...
    [ u_D1 + v_D1, ..., u_DNdir + v_DNdir]]
103
    For ufuncs this arrangement is advantageous, because in this order,
    memory chunks of size Ndir are used and the operation on each element is the
    same. This is desirable to avoid cache misses.
    See for example __mul__: there, operations of the form

    self.data[:d+1,:,:,:]* rhs.data[d::-1,:,:,:]

    have to be performed.
    One can see that contiguous memory blocks are used for such operations.
113
114    A disadvantage of this arrangement is: it seems unnatural.
115    It is easier to regard each direction separately.
116    """
117
118    __array_priority__ = 2
119
120    def __init__(self, X):
121        """
122
123        INPUT:
124        shape([X]) = (D,P,N,M)
125        """
126        Ndim = numpy.ndim(X)
127        if Ndim >= 2:
128            self.data = numpy.asarray(X)
129            self.data = self.data
130        else:
131            raise NotImplementedError()
132
133    def __getitem__(self, sl):
134        if not isinstance(sl, tuple):
135            sl = (sl,)
136        tmp = self.data.__getitem__((slice(None),slice(None)) + sl)
137        return self.__class__(tmp)
138
139    def __setitem__(self, sl, rhs):
140        if isinstance(rhs, UTPM):
141            if type(sl) == int or sl == Ellipsis or isinstance(sl, slice):
142                sl = (sl,)
143            x_data, y_data = UTPM._broadcast_arrays(self.data.__getitem__((slice(None),slice(None)) + sl), rhs.data)
144            return x_data.__setitem__(Ellipsis, y_data)
145        else:
146            if type(sl) == int or sl == Ellipsis or isinstance(sl, slice):
147                sl = (sl,)
148            self.data.__setitem__((slice(1,None),slice(None)) + sl, 0)
149            return self.data.__setitem__((0,slice(None)) + sl, rhs)
150
151
152    @property
153    def dtype(self):
154        return self.data.dtype
155
    @classmethod
    def pb___getitem__(cls, ybar, x, sl, y, out = None):
        """
        y = getitem(x, sl)

        Pullback of getitem: propagates the adjoint ybar of y = x[sl] into
        out[0] at the position sl.

        Parameters
        ----------
        ybar:   adjoint of y
        x:      argument of the forward getitem (not used here)
        sl:     index used in the forward call
        y:      result of the forward getitem (not used here)
        out:    sequence whose first entry receives the adjoint; mandatory

        Returns
        -------
        out

        Raises
        ------
        NotImplementedError
            if out is None

        Warning:
        this includes a workaround for tuples, e.g. for Q,R = qr(A)
        where A,Q,R are Function objects
        """
        if out is None:
            raise NotImplementedError('I\'m not sure that this makes sense')

        # workaround for qr and eigh
        # NOTE(review): tmp is a new list, so `tmp[sl] += ybar` only affects
        # out[0] if the element's += works in place (e.g. UTPM.__iadd__);
        # the rebinding of tmp[sl] itself is discarded -- confirm intended.
        if isinstance( out[0], tuple):
            tmp = list(out[0])
            tmp[sl] += ybar

        # usual workflow
        else:
            # NOTE(review): ybar is assigned rather than accumulated
            # (out[0][sl] += ybar); confirm that no other contribution to
            # out[0][sl] can be overwritten here.
            # print 'out=\n', out[0][sl]
            # print 'ybar=\n',ybar
            out[0][sl] = ybar

        return out
180
181    @classmethod
182    def pb_getitem(cls, ybar, x, sl, y, out = None):
183        # print 'ybar=\n',ybar
184        retval = cls.pb___getitem__(ybar, x, sl, y, out = out)
185        # print 'retval=\n',retval[0]
186        return retval
187
188
189    @classmethod
190    def as_utpm(cls, x):
191        """ tries to convert a container (e.g. list or numpy.array) with UTPM elements as instances to a UTPM instance"""
192
193        x_shp = numpy.shape(x)
194        xr = numpy.ravel(x)
195
196        # print 'x=', x
197        # print 'xr=',xr
198        # print 'x.dtype', x.dtype
199        D,P = xr[0].data.shape[:2]
200        shp = xr[0].data.shape[2:]
201
202        if not isinstance(shp, tuple): shp = (shp,)
203        if not isinstance(x_shp, tuple): x_shp = (x_shp,)
204
205        y = UTPM(numpy.zeros((D,P) + x_shp + shp))
206
207        yr = UTPM( y.data.reshape((D,P) + (numpy.prod(x_shp, dtype=int),) + shp))
208
209        # print yr.shape
210        # print yr.data.shape
211
212        for n in range(len(xr)):
213            # print yr[n].shape
214            # print xr[n].shape
215            yr[n] = xr[n]
216
217        return y
218
219
220    def get_flat(self):
221        return UTPM(self.data.reshape(self.data.shape[:2] + (numpy.prod(self.data.shape[2:]),) ))
222
223    flat = property(get_flat)
224
225    def coeff_op(self, sl, shp):
226        """
227        operation to extract UTP coefficients of x
228        defined by the slice sl creates a new
229        UTPM instance where the coefficients have the shape as defined
230        by shp
231
232        Parameters
233        ----------
234        x: UTPM instance
235        sl: tuple of slice instance
236        shp: tuple
237
238        Returns
239        -------
240        UTPM instance
241
242        """
243
244        tmp = self.data.__getitem__(sl)
245        tmp = tmp.reshape(shp)
246        return self.__class__(tmp)
247
248    @classmethod
249    def pb_coeff_op(cls, ybar, x, sl, shp, out = None):
250
251        if out is None:
252            D,P = x.data.shape[:2]
253            xbar = x.zeros_like()
254
255        else:
256            xbar = out[0]
257
258        # step 1: revert reshape
259        old_shp = x.data.__getitem__(sl).shape
260        tmp_data = ybar.data.reshape(old_shp)
261
262        # print('tmp_data.shape=',tmp_data.shape)
263
264        # step 2: revert getitem
265        tmp2 = xbar.data[::-1].__getitem__(sl)
266        tmp2 += tmp_data[::-1,...]
267
268        return xbar
269
270
271
272    @classmethod
273    def pb___setitem__(cls, y, sl, x, out = None):
274        """
275        y.__setitem(sl,x)
276        """
277
278        if out is None:
279            raise NotImplementedError('I\'m not sure if this makes sense')
280
281        ybar, dummy, xbar = out
282        # print 'xbar =', xbar
283        # print 'ybar =', ybar
284        xbar += ybar[sl]
285        ybar[sl].data[...] = 0.
286        # print 'funcargs=',funcargs
287        # print y[funcargs[0]]
288
289    @classmethod
290    def pb_setitem(cls, y, sl, x, out = None):
291        return cls.pb___setitem__(y, sl, x, out = out)
292
293    def __add__(self,rhs):
294        if numpy.isscalar(rhs):
295            dtype = numpy.promote_types(self.data.dtype, type(rhs))
296            retval = UTPM(numpy.zeros(self.data.shape, dtype=dtype))
297            retval.data[...] = self.data
298            retval.data[0,:] += rhs
299            return retval
300
301        elif isinstance(rhs,numpy.ndarray) and rhs.dtype == object:
302
303            if not isinstance(rhs.flatten()[0], UTPM):
304                err_str = 'you are trying to perform an operation involving 1) a UTPM instance and 2)a numpy.ndarray with elements of type %s\n'%type(rhs.flatten()[0])
305                err_str+= 'this operation is not supported!\n'
306                raise NotImplementedError(err_str)
307            else:
308                err_str = 'binary operations between UTPM instances and object arrays are not supported'
309                raise NotImplementedError(err_str)
310
311        elif isinstance(rhs, numpy.ndarray):
312            rhs_shape = rhs.shape
313            if numpy.isscalar(rhs_shape):
314                rhs_shape = (rhs_shape,)
315            x_data, y_data = UTPM._broadcast_arrays(self.data, rhs.reshape((1,1)+rhs_shape))
316            dtype = numpy.promote_types(x_data.dtype, y_data.dtype)
317            z_data = numpy.zeros(x_data.shape, dtype=dtype)
318            z_data[...] = x_data
319            z_data[0] += y_data[0]
320            return UTPM(z_data)
321
322        else:
323            x_data, y_data = UTPM._broadcast_arrays(self.data, rhs.data)
324            return UTPM(x_data + y_data)
325
326    def __sub__(self,rhs):
327        if numpy.isscalar(rhs):
328            dtype = numpy.promote_types(self.data.dtype, type(rhs))
329            retval = UTPM(numpy.zeros(self.data.shape, dtype=dtype))
330            retval.data[...] = self.data
331            retval.data[0,:] -= rhs
332            return retval
333
334        elif isinstance(rhs,numpy.ndarray) and rhs.dtype == object:
335            if not isinstance(rhs.flatten()[0], UTPM):
336                err_str = 'you are trying to perform an operation involving 1) a UTPM instance and 2)a numpy.ndarray with elements of type %s\n'%type(rhs.flatten()[0])
337                err_str+= 'this operation is not supported!\n'
338                raise NotImplementedError(err_str)
339            else:
340                err_str = 'binary operations between UTPM instances and object arrays are not supported'
341                raise NotImplementedError(err_str)
342
343        elif isinstance(rhs, numpy.ndarray):
344            rhs_shape = rhs.shape
345            if numpy.isscalar(rhs_shape):
346                rhs_shape = (rhs_shape,)
347            x_data, y_data = UTPM._broadcast_arrays(self.data, rhs.reshape((1,1)+rhs_shape))
348            dtype = numpy.promote_types(x_data.dtype, y_data.dtype)
349            z_data = numpy.zeros(x_data.shape, dtype=dtype)
350            z_data[...] = x_data
351            z_data[0] -= y_data[0]
352            return UTPM(z_data)
353
354        else:
355            x_data, y_data = UTPM._broadcast_arrays(self.data, rhs.data)
356            return UTPM(x_data - y_data)
357
358    def __mul__(self,rhs):
359        if numpy.isscalar(rhs):
360            return UTPM( self.data * rhs)
361
362        elif isinstance(rhs,numpy.ndarray) and rhs.dtype == object:
363            if not isinstance(rhs.flatten()[0], UTPM):
364                err_str = 'you are trying to perform an operation involving 1) a UTPM instance and 2)a numpy.ndarray with elements of type %s\n'%type(rhs.flatten()[0])
365                err_str+= 'this operation is not supported!\n'
366                raise NotImplementedError(err_str)
367            else:
368                err_str = 'binary operations between UTPM instances and object arrays are not supported'
369                raise NotImplementedError(err_str)
370
371        elif isinstance(rhs,numpy.ndarray):
372            rhs_shape = rhs.shape
373            if numpy.isscalar(rhs_shape):
374                rhs_shape = (rhs_shape,)
375            x_data, y_data = UTPM._broadcast_arrays(self.data, rhs.reshape((1,1)+rhs_shape))
376            return UTPM(x_data * y_data)
377
378        x_data, y_data = UTPM._broadcast_arrays(self.data, rhs.data)
379        dtype = numpy.promote_types(x_data.dtype, y_data.dtype)
380        z_data = numpy.zeros(x_data.shape, dtype=dtype)
381        self._mul(x_data, y_data, z_data)
382        return self.__class__(z_data)
383
384    def __truediv__(self,rhs):
385        if numpy.isscalar(rhs):
386            return UTPM( self.data/rhs)
387
388        elif isinstance(rhs,numpy.ndarray) and rhs.dtype == object:
389            if not isinstance(rhs.flatten()[0], UTPM):
390                err_str = 'you are trying to perform an operation involving 1) a UTPM instance and 2)a numpy.ndarray with elements of type %s\n'%type(rhs.flatten()[0])
391                err_str+= 'this operation is not supported!\n'
392                raise NotImplementedError(err_str)
393            else:
394                err_str = 'binary operations between UTPM instances and object arrays are not supported'
395                raise NotImplementedError(err_str)
396
397
398        elif isinstance(rhs,numpy.ndarray):
399            rhs_shape = rhs.shape
400            if numpy.isscalar(rhs_shape):
401                rhs_shape = (rhs_shape,)
402            x_data, y_data = UTPM._broadcast_arrays(self.data, rhs.reshape((1,1)+rhs_shape))
403            return UTPM(x_data / y_data)
404
405        x_data, y_data = UTPM._broadcast_arrays(self.data, rhs.data)
406        dtype = numpy.promote_types(x_data.dtype, y_data.dtype)
407        z_data = numpy.zeros(x_data.shape, dtype=dtype)
408        self._truediv(x_data, y_data, z_data)
409        return self.__class__(z_data)
410
411    def __floordiv__(self, rhs):
412        """
413        self // rhs
414
415        use L'Hopital's rule
416        """
417
418        x_data, y_data = UTPM._broadcast_arrays(self.data, rhs.data)
419        dtype = numpy.promote_types(x_data.dtype, y_data.dtype)
420        z_data = numpy.zeros(x_data.shape, dtype=dtype)
421        self._floordiv(x_data, y_data, z_data)
422        return self.__class__(z_data)
423
424    def __pow__(self,r):
425        if isinstance(r, UTPM):
426            return UTPM.exp(UTPM.log(self)*r)
427        else:
428            x_data = self.data
429            y_data = numpy.zeros_like(x_data)
430            self._pow_real(x_data, r, y_data)
431            return self.__class__(y_data)
432
433    def __rpow__(self,r):
434        return UTPM.exp(numpy.log(r)*self)
435
436
437    @classmethod
438    def pb___pow__(cls, ybar, x, r, y, out = None):
439        if out is None:
440            D,P = x.data.shape[:2]
441            xbar = x.zeros_like()
442
443        else:
444            xbar = out[0]
445
446        if isinstance(r, cls):
447            raise NotImplementedError('r must be int or float, or use the identity x**y = exp(log(x)*y)')
448
449        cls._pb_pow_real(ybar.data, x.data, r, y.data, out = xbar.data)
450        return xbar
451
452    @classmethod
453    def pb_pow(cls, ybar, x, r, y, out = None):
454        retval = cls.pb___pow__(ybar, x, r, y, out = out)
455        return retval
456
457
458    def __radd__(self,rhs):
459        return self + rhs
460
461    def __rsub__(self, other):
462        return -self + other
463
464    def __rmul__(self,rhs):
465        return self * rhs
466
467    def __rtruediv__(self, rhs):
468        tmp = self.zeros_like()
469        tmp.data[0,...] = rhs
470        return tmp/self
471
472    def __iadd__(self,rhs):
473        if isinstance(rhs,numpy.ndarray) and rhs.dtype == object:
474            raise NotImplementedError('should implement that')
475
476        elif numpy.isscalar(rhs) or isinstance(rhs,numpy.ndarray):
477            self.data[0,...] += rhs
478        else:
479            self_data, rhs_data = UTPM._broadcast_arrays(self.data, rhs.data)
480            # self_data[...] += rhs_data[...]
481            numpy.add(self_data, rhs_data, out=self_data, casting="unsafe")
482        return self
483
484    def __isub__(self,rhs):
485        if isinstance(rhs,numpy.ndarray) and rhs.dtype == object:
486            raise NotImplementedError('should implement that')
487
488        elif numpy.isscalar(rhs) or isinstance(rhs,numpy.ndarray):
489            self.data[0,...] -= rhs
490        else:
491            self_data, rhs_data = UTPM._broadcast_arrays(self.data, rhs.data)
492            self_data[...] -= rhs_data[...]
493        return self
494
495    def __imul__(self,rhs):
496        (D,P) = self.data.shape[:2]
497
498        if isinstance(rhs,numpy.ndarray) and rhs.dtype == object:
499            raise NotImplementedError('should implement that')
500
501        elif numpy.isscalar(rhs) or isinstance(rhs,numpy.ndarray):
502            for d in range(D):
503                for p in range(P):
504                    self.data[d,p,...] *= rhs
505        else:
506            for d in range(D)[::-1]:
507                for p in range(P):
508                    self.data[d,p,...] *= rhs.data[0,p,...]
509                    for c in range(d):
510                        self.data[d,p,...] += self.data[c,p,...] * rhs.data[d-c,p,...]
511        return self
512
513    def __itruediv__(self,rhs):
514        (D,P) = self.data.shape[:2]
515        if isinstance(rhs,numpy.ndarray) and rhs.dtype == object:
516            raise NotImplementedError('should implement that')
517
518        elif numpy.isscalar(rhs) or isinstance(rhs,numpy.ndarray):
519            self.data[...] /= rhs
520        else:
521            retval = self.clone()
522            for d in range(D):
523                retval.data[d,:,...] = 1./ rhs.data[0,:,...] * ( self.data[d,:,...] - numpy.sum(retval.data[:d,:,...] * rhs.data[d:0:-1,:,...], axis=0))
524            self.data[...] = retval.data[...]
525        return self
526
    # Python 2 names of the division operators; they reuse the true-division
    # implementations defined above.
    __div__ = __truediv__
    __idiv__ = __itruediv__
    __rdiv__ = __rtruediv__
530
531    def sqrt(self):
532        retval = self.clone()
533        self._sqrt(self.data, out = retval.data)
534        return retval
535
536    @classmethod
537    def pb_sqrt(cls, ybar, x, y, out=None):
538        """ computes bar y dy = bar x dx in UTP arithmetic"""
539        if out is None:
540            D,P = x.data.shape[:2]
541            xbar = x.zeros_like()
542
543        else:
544            xbar, = out
545
546        cls._pb_sqrt(ybar.data, x.data, y.data, out = xbar.data)
547        return out
548
549    def exp(self):
550        """ computes y = exp(x) in UTP arithmetic"""
551
552        retval = self.clone()
553        self._exp(self.data, out = retval.data)
554        return retval
555
556    @classmethod
557    def pb_exp(cls, ybar, x, y, out=None):
558        """ computes bar y dy = bar x dx in UTP arithmetic"""
559        if out is None:
560            D,P = x.data.shape[:2]
561            xbar = x.zeros_like()
562
563        else:
564            xbar, = out
565
566        cls._pb_exp(ybar.data, x.data, y.data, out = xbar.data)
567        return out
568
569    def expm1(self):
570        """ computes y = expm1(x) in UTP arithmetic"""
571
572        retval = self.clone()
573        self._expm1(self.data, out = retval.data)
574        return retval
575
576    @classmethod
577    def pb_expm1(cls, ybar, x, y, out=None):
578        """ computes bar y dy = bar x dx in UTP arithmetic"""
579        if out is None:
580            D,P = x.data.shape[:2]
581            xbar = x.zeros_like()
582
583        else:
584            xbar, = out
585
586        cls._pb_expm1(ybar.data, x.data, y.data, out = xbar.data)
587        return out
588
589    def log(self):
590        """ computes y = log(x) in UTP arithmetic"""
591        retval = self.clone()
592        self._log(self.data, out = retval.data)
593        return retval
594
595    @classmethod
596    def pb_log(cls, ybar, x, y, out=None):
597        """ computes bar y dy = bar x dx in UTP arithmetic"""
598        if out is None:
599            D,P = x.data.shape[:2]
600            xbar = x.zeros_like()
601
602        else:
603            xbar, = out
604
605        cls._pb_log(ybar.data, x.data, y.data, out = xbar.data)
606        return xbar
607
608    def log1p(self):
609        """ computes y = log1p(x) in UTP arithmetic"""
610        retval = self.clone()
611        self._log1p(self.data, out = retval.data)
612        return retval
613
614    @classmethod
615    def pb_log1p(cls, ybar, x, y, out=None):
616        """ computes bar y dy = bar x dx in UTP arithmetic"""
617        if out is None:
618            D,P = x.data.shape[:2]
619            xbar = x.zeros_like()
620
621        else:
622            xbar, = out
623
624        cls._pb_log1p(ybar.data, x.data, y.data, out = xbar.data)
625        return out
626
627    def sincos(self):
628        """ simultanteously computes s = sin(x) and c = cos(x) in UTP arithmetic"""
629        retsin = self.clone()
630        retcos = self.clone()
631        self._sincos(self.data, out = (retsin.data, retcos.data))
632        return retsin, retcos
633
634    def sin(self):
635        retval = self.clone()
636        tmp = self.clone()
637        self._sincos(self.data, out = (retval.data, tmp.data))
638        return retval
639
640    @classmethod
641    def pb_sin(cls, sbar, x, s,  out = None):
642        if out is None:
643            D,P = x.data.shape[:2]
644            xbar = x.zeros_like()
645
646        else:
647            xbar, = out
648
649        c = x.cos()
650        cbar = x.zeros_like()
651        cls._pb_sincos(sbar.data, cbar.data, x.data, s.data, c.data, out = xbar.data)
652        return out
653
654    def cos(self):
655        retval = self.clone()
656        tmp = self.clone()
657        self._sincos(self.data, out = (tmp.data, retval.data))
658        return retval
659
660    @classmethod
661    def pb_cos(cls, cbar, x, c,  out = None):
662        if out is None:
663            D,P = x.data.shape[:2]
664            xbar = x.zeros_like()
665
666        else:
667            xbar, = out
668
669        s = x.sin()
670        sbar = x.zeros_like()
671        cls._pb_sincos(sbar.data, cbar.data, x.data, s.data, c.data, out = xbar.data)
672        return out
673
674
675    def tansec2(self):
676        """ computes simultaneously y = tan(x) and z = sec^2(x)  in UTP arithmetic"""
677        rettan = self.clone()
678        retsec = self.clone()
679        self._tansec2(self.data, out = (rettan.data, retsec.data))
680        return rettan, retset
681
682    def tan(self):
683        retval = self.zeros_like()
684        tmp = self.zeros_like()
685        self._tansec2(self.data, out = (retval.data, tmp.data))
686        return retval
687
688    @classmethod
689    def pb_tan(cls, ybar, x, y,  out = None):
690        if out is None:
691            D,P = x.data.shape[:2]
692            xbar = x.zeros_like()
693
694        else:
695            xbar, = out
696
697        z = 1./x.cos(); z = z * z
698        zbar = x.zeros_like()
699        cls._pb_tansec(ybar.data, zbar.data, x.data, y.data, z.data, out = xbar.data)
700        return out
701
702    @classmethod
703    def dpm_hyp1f1(cls, a, b, x):
704        """ computes y = hyp1f1(a, b, x) in UTP arithmetic"""
705
706        retval = x.clone()
707        cls._dpm_hyp1f1(a, b, x.data, out = retval.data)
708        return retval
709
710    @classmethod
711    def pb_dpm_hyp1f1(cls, ybar, a, b, x, y, out=None):
712        """ computes bar y dy = bar x dx in UTP arithmetic"""
713        if out is None:
714            D,P = x.data.shape[:2]
715            xbar = x.zeros_like()
716
717        else:
718            # out = (abar, bbar, xbar)
719            xbar = out[2]
720
721        cls._pb_dpm_hyp1f1(ybar.data, a, b, x.data, y.data, out = xbar.data)
722
723        return xbar
724
725    @classmethod
726    def hyp1f1(cls, a, b, x):
727        """ computes y = hyp1f1(a, b, x) in UTP arithmetic"""
728
729        retval = x.clone()
730        cls._hyp1f1(a, b, x.data, out = retval.data)
731        return retval
732
733    @classmethod
734    def pb_hyp1f1(cls, ybar, a, b, x, y, out=None):
735        """ computes bar y dy = bar x dx in UTP arithmetic"""
736        if out is None:
737            D,P = x.data.shape[:2]
738            xbar = x.zeros_like()
739
740        else:
741            # out = (abar, bbar, xbar)
742            xbar = out[2]
743
744        cls._pb_hyp1f1(ybar.data, a, b, x.data, y.data, out = xbar.data)
745
746        return xbar
747
748    @classmethod
749    def hyperu(cls, a, b, x):
750        """ computes y = hyperu(a, b, x) in UTP arithmetic"""
751        retval = x.clone()
752        cls._hyperu(a, b, x.data, out = retval.data)
753        return retval
754
755    @classmethod
756    def pb_hyperu(cls, ybar, a, b, x, y, out=None):
757        """ computes bar y dy = bar x dx in UTP arithmetic"""
758        if out is None:
759            D,P = x.data.shape[:2]
760            xbar = x.zeros_like()
761        else:
762            # out = (abar, bbar, xbar)
763            xbar = out[2]
764        cls._pb_hyperu(ybar.data, a, b, x.data, y.data, out = xbar.data)
765        return xbar
766
767    @classmethod
768    def botched_clip(cls, a_min, a_max, x):
769        """ computes y = botched_clip(a_min, a_max, x) in UTP arithmetic"""
770        retval = x.clone()
771        cls._botched_clip(a_min, a_max, x.data, out = retval.data)
772        return retval
773
774    @classmethod
775    def pb_botched_clip(cls, ybar, a_min, a_max, x, y, out=None):
776        """ computes bar y dy = bar x dx in UTP arithmetic"""
777        if out is None:
778            D,P = x.data.shape[:2]
779            xbar = x.zeros_like()
780        else:
781            # out = (aminbar, amaxbar, xbar)
782            xbar = out[2]
783        cls._pb_botched_clip(
784                ybar.data, a_min, a_max, x.data, y.data, out = xbar.data)
785        return xbar
786
787    @classmethod
788    def dpm_hyp2f0(cls, a1, a2, x):
789        """ computes y = hyp2f0(a1, a2, x) in UTP arithmetic"""
790
791        retval = x.clone()
792        cls._dpm_hyp2f0(a1, a2, x.data, out = retval.data)
793        return retval
794
795    @classmethod
796    def pb_dpm_hyp2f0(cls, ybar, a1, a2, x, y, out=None):
797        """ computes bar y dy = bar x dx in UTP arithmetic"""
798        if out is None:
799            D,P = x.data.shape[:2]
800            xbar = x.zeros_like()
801
802        else:
803            # out = (a1bar, a2bar, xbar)
804            xbar = out[2]
805
806        cls._pb_dpm_hyp2f0(ybar.data, a1, a2, x.data, y.data, out = xbar.data)
807
808        return xbar
809
810    @classmethod
811    def hyp2f0(cls, a1, a2, x):
812        """ computes y = hyp2f0(a1, a2, x) in UTP arithmetic"""
813
814        retval = x.clone()
815        cls._hyp2f0(a1, a2, x.data, out = retval.data)
816        return retval
817
818    @classmethod
819    def pb_hyp2f0(cls, ybar, a1, a2, x, y, out=None):
820        """ computes bar y dy = bar x dx in UTP arithmetic"""
821        if out is None:
822            D,P = x.data.shape[:2]
823            xbar = x.zeros_like()
824
825        else:
826            # out = (a1bar, a2bar, xbar)
827            xbar = out[2]
828
829        cls._pb_hyp2f0(ybar.data, a1, a2, x.data, y.data, out = xbar.data)
830
831        return xbar
832
833    @classmethod
834    def hyp0f1(cls, b, x):
835        """ computes y = hyp0f1(b, x) in UTP arithmetic"""
836
837        retval = x.clone()
838        cls._hyp0f1(b, x.data, out = retval.data)
839        return retval
840
841    @classmethod
842    def pb_hyp0f1(cls, ybar, b, x, y, out=None):
843        """ computes bar y dy = bar x dx in UTP arithmetic"""
844        if out is None:
845            D,P = x.data.shape[:2]
846            xbar = x.zeros_like()
847
848        else:
849            # out = (bbar, xbar)
850            xbar = out[1]
851
852        cls._pb_hyp0f1(ybar.data, b, x.data, y.data, out = xbar.data)
853
854        return xbar
855
856    @classmethod
857    def polygamma(cls, n, x):
858        """ computes y = polygamma(n, x) in UTP arithmetic"""
859
860        retval = x.clone()
861        cls._polygamma(n, x.data, out = retval.data)
862        return retval
863
864    @classmethod
865    def pb_polygamma(cls, ybar, n, x, y, out=None):
866        """ computes bar y dy = bar x dx in UTP arithmetic"""
867        if out is None:
868            D,P = x.data.shape[:2]
869            xbar = x.zeros_like()
870
871        else:
872            # out = (nbar, xbar)
873            xbar = out[1]
874
875        cls._pb_polygamma(ybar.data, n, x.data, y.data, out = xbar.data)
876
877        return xbar
878
879    @classmethod
880    def psi(cls, x):
881        """ computes y = psi(x) in UTP arithmetic"""
882
883        retval = x.clone()
884        cls._psi(x.data, out = retval.data)
885        return retval
886
887    @classmethod
888    def pb_psi(cls, ybar, x, y, out=None):
889        """ computes bar y dy = bar x dx in UTP arithmetic"""
890        if out is None:
891            D,P = x.data.shape[:2]
892            xbar = x.zeros_like()
893
894        else:
895            xbar, = out
896
897        cls._pb_psi(ybar.data, x.data, y.data, out = xbar.data)
898        return xbar
899
900    @classmethod
901    def reciprocal(cls, x):
902        """ computes y = reciprocal(x) in UTP arithmetic"""
903
904        retval = x.clone()
905        cls._reciprocal(x.data, out = retval.data)
906        return retval
907
908    @classmethod
909    def pb_reciprocal(cls, ybar, x, y, out=None):
910        """ computes bar y dy = bar x dx in UTP arithmetic"""
911        if out is None:
912            D,P = x.data.shape[:2]
913            xbar = x.zeros_like()
914
915        else:
916            xbar, = out
917
918        cls._pb_reciprocal(ybar.data, x.data, y.data, out = xbar.data)
919        return xbar
920
921    @classmethod
922    def gammaln(cls, x):
923        """ computes y = gammaln(x) in UTP arithmetic"""
924
925        retval = x.clone()
926        cls._gammaln(x.data, out = retval.data)
927        return retval
928
929    @classmethod
930    def pb_gammaln(cls, ybar, x, y, out=None):
931        """ computes bar y dy = bar x dx in UTP arithmetic"""
932        if out is None:
933            D,P = x.data.shape[:2]
934            xbar = x.zeros_like()
935
936        else:
937            xbar, = out
938
939        cls._pb_gammaln(ybar.data, x.data, y.data, out = xbar.data)
940        return xbar
941
942    @classmethod
943    def minimum(cls, x, y):
944        # FIXME: this typechecking is probably not flexible enough
945        # FIXME: also add pullback
946        if isinstance(x, UTPM) and isinstance(y, UTPM):
947            return UTPM(cls._minimum(x.data, y.data))
948        elif isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
949            return numpy.minimum(x, y)
950        else:
951            raise NotImplementedError(
952                    'this combination of types is not yet implemented')
953
954    @classmethod
955    def maximum(cls, x, y):
956        # FIXME: this typechecking is probably not flexible enough
957        # FIXME: also add pullback
958        if isinstance(x, UTPM) and isinstance(y, UTPM):
959            return UTPM(cls._maximum(x.data, y.data))
960        elif isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
961            return numpy.maximum(x, y)
962        else:
963            raise NotImplementedError(
964                    'this combination of types is not yet implemented')
965
966    @classmethod
967    def real(cls, x):
968        """ UTPM equivalent to numpy.real """
969        return cls(x.data.real)
970
971    @classmethod
972    def pb_real(cls, ybar, x, y, out=None):
973        if out is None:
974            D,P = x.data.shape[:2]
975            xbar = x.zeros_like()
976
977        else:
978            xbar, = out
979
980        xbar.data.real = ybar.data
981
982    @classmethod
983    def imag(cls, x):
984        """ UTPM equivalent to numpy.imag """
985        return cls(x.data.imag)
986
987    @classmethod
988    def pb_imag(cls, ybar, x, y, out=None):
989        if out is None:
990            D,P = x.data.shape[:2]
991            xbar = x.zeros_like()
992
993        else:
994            xbar, = out
995        xbar.data.imag = -ybar.data
996
997
998    @classmethod
999    def absolute(cls, x):
1000        """ computes y = absolute(x) in UTP arithmetic"""
1001
1002        retval = x.clone()
1003        cls._absolute(x.data, out = retval.data)
1004        return retval
1005
1006    @classmethod
1007    def pb_absolute(cls, ybar, x, y, out=None):
1008        """ computes ybar * ydot = xbar * xdot in UTP arithmetic"""
1009
1010        if out is None:
1011            D,P = x.data.shape[:2]
1012            xbar = x.zeros_like()
1013
1014        else:
1015            xbar, = out
1016
1017        cls._pb_absolute(ybar.data, x.data, y.data, out = xbar.data)
1018        return xbar
1019
1020    @classmethod
1021    def negative(cls, x):
1022        """ computes y = negative(x) in UTP arithmetic"""
1023
1024        retval = x.clone()
1025        cls._negative(x.data, out = retval.data)
1026        return retval
1027
1028    @classmethod
1029    def pb_negative(cls, ybar, x, y, out=None):
1030        """ computes ybar * ydot = xbar * xdot in UTP arithmetic"""
1031
1032        if out is None:
1033            D,P = x.data.shape[:2]
1034            xbar = x.zeros_like()
1035
1036        else:
1037            xbar, = out
1038
1039        cls._pb_negative(ybar.data, x.data, y.data, out = xbar.data)
1040        return xbar
1041
1042    @classmethod
1043    def square(cls, x):
1044        """ computes y = square(x) in UTP arithmetic"""
1045
1046        retval = x.clone()
1047        cls._square(x.data, out = retval.data)
1048        return retval
1049
1050    @classmethod
1051    def pb_square(cls, ybar, x, y, out=None):
1052        """ computes ybar * ydot = xbar * xdot in UTP arithmetic"""
1053
1054        if out is None:
1055            D,P = x.data.shape[:2]
1056            xbar = x.zeros_like()
1057
1058        else:
1059            xbar, = out
1060
1061        cls._pb_square(ybar.data, x.data, y.data, out = xbar.data)
1062        return xbar
1063
1064    @classmethod
1065    def erf(cls, x):
1066        """ computes y = erf(x) in UTP arithmetic"""
1067
1068        retval = x.clone()
1069        cls._erf(x.data, out = retval.data)
1070        return retval
1071
1072
1073    @classmethod
1074    def pb_erf(cls, ybar, x, y, out=None):
1075        """ computes ybar * ydot = xbar * xdot in UTP arithmetic"""
1076
1077        if out is None:
1078            D,P = x.data.shape[:2]
1079            xbar = x.zeros_like()
1080
1081        else:
1082            xbar, = out
1083
1084        cls._pb_erf(ybar.data, x.data, y.data, out = xbar.data)
1085        return xbar
1086
1087
1088    @classmethod
1089    def erfi(cls, x):
1090        """ computes y = erfi(x) in UTP arithmetic"""
1091
1092        retval = x.clone()
1093        cls._erfi(x.data, out = retval.data)
1094        return retval
1095
1096    @classmethod
1097    def pb_erfi(cls, ybar, x, y, out=None):
1098        """ computes ybar * ydot = xbar * xdot in UTP arithmetic"""
1099
1100        if out is None:
1101            D,P = x.data.shape[:2]
1102            xbar = x.zeros_like()
1103
1104        else:
1105            xbar, = out
1106
1107        cls._pb_erfi(ybar.data, x.data, y.data, out = xbar.data)
1108        return xbar
1109
1110
1111    @classmethod
1112    def dawsn(cls, x):
1113        """ computes y = dawsn(x) in UTP arithmetic"""
1114
1115        retval = x.clone()
1116        cls._dawsn(x.data, out = retval.data)
1117        return retval
1118
1119    @classmethod
1120    def pb_dawsn(cls, ybar, x, y, out=None):
1121        """ computes ybar * ydot = xbar * xdot in UTP arithmetic"""
1122
1123        if out is None:
1124            D,P = x.data.shape[:2]
1125            xbar = x.zeros_like()
1126
1127        else:
1128            xbar, = out
1129
1130        cls._pb_dawsn(ybar.data, x.data, y.data, out = xbar.data)
1131        return xbar
1132
1133
1134    @classmethod
1135    def logit(cls, x):
1136        """ computes y = logit(x) in UTP arithmetic"""
1137
1138        retval = x.clone()
1139        cls._logit(x.data, out = retval.data)
1140        return retval
1141
1142    @classmethod
1143    def pb_logit(cls, ybar, x, y, out=None):
1144        """ computes ybar * ydot = xbar * xdot in UTP arithmetic"""
1145
1146        if out is None:
1147            D,P = x.data.shape[:2]
1148            xbar = x.zeros_like()
1149
1150        else:
1151            xbar, = out
1152
1153        cls._pb_logit(ybar.data, x.data, y.data, out = xbar.data)
1154        return xbar
1155
1156    @classmethod
1157    def expit(cls, x):
1158        """ computes y = expit(x) in UTP arithmetic"""
1159
1160        retval = x.clone()
1161        cls._expit(x.data, out = retval.data)
1162        return retval
1163
1164    @classmethod
1165    def pb_expit(cls, ybar, x, y, out=None):
1166        """ computes ybar * ydot = xbar * xdot in UTP arithmetic"""
1167
1168        if out is None:
1169            D,P = x.data.shape[:2]
1170            xbar = x.zeros_like()
1171
1172        else:
1173            xbar, = out
1174
1175        cls._pb_expit(ybar.data, x.data, y.data, out = xbar.data)
1176        return xbar
1177
1178
1179    def sum(self, axis=None, dtype=None, out=None):
1180        if dtype is not None or out is not None:
1181            raise NotImplementedError('not implemented yet')
1182
1183        if axis is None:
1184            tmp = numpy.prod(self.data.shape[2:])
1185            return UTPM(numpy.sum(self.data.reshape(self.data.shape[:2] + (tmp,)), axis = 2))
1186        else:
1187            if axis < 0:
1188                a = self.data.ndim + axis
1189            else:
1190                a = axis + 2
1191            return UTPM(numpy.sum(self.data, axis = a))
1192
1193    @classmethod
1194    def pb_sum(cls, ybar, x, y, axis, dtype, out2, out = None):
1195
1196        if out is None:
1197            D,P = x.data.shape[:2]
1198            xbar = x.zeros_like()
1199
1200        else:
1201            xbar = out[0]
1202
1203        if axis is None:
1204
1205            tmp = xbar.data.T
1206            tmp += ybar.data.T
1207
1208        else:
1209
1210            if axis < 0:
1211                a = x.data.ndim + axis
1212
1213            else:
1214                a = axis + 2
1215
1216            shp = list(x.data.shape)
1217            shp[a] = 1
1218            tmp = ybar.data.reshape(shp)
1219            xbar.data += tmp
1220
1221        return xbar
1222
1223    def prod(self):
1224        x = self
1225        D,P = x.data.shape[:2]
1226        y = UTPM(numpy.zeros((D,P), dtype=x.data.dtype))
1227        y.data[0,:] = 1.
1228        for i in range(0, x.size):
1229            y *= x[i]
1230        return y
1231
1232    @classmethod
1233    def pb_prod(cls, ybar, x, y, out=None):
1234        D,P = x.data.shape[:2]
1235        if out is None:
1236            xbar = x.zeros_like()
1237
1238        else:
1239            xbar, = out
1240
1241        # forward and store intermediates
1242        z = x.zeros_like()
1243        zbar = x.zeros_like()
1244        z.data[0,:, 0] = 1.
1245        z[0] = x[0]
1246        for i in range(1, x.size):
1247            z[i] = z[i-1]*x[i]
1248
1249        # reverse
1250        zbar[x.size-1] = ybar
1251        for i in range(x.size-1, 0, -1):
1252            zbar[i-1] += zbar[i]*x[i]
1253            xbar[i]   += zbar[i]*z[i-1]
1254        xbar[0] = zbar[0]
1255        return xbar
1256
1257        # z = y.copy()
1258        # zbar = ybar.copy()
1259        # for i in range(x.size-1, -1, -1):
1260        #     xbar[i] += ybar*z
1261        #     zbar *= x[i]
1262        #     z /= x[i]
1263        # return xbar
1264
1265
1266    @classmethod
1267    def pb_sincos(cls, sbar, cbar, x, s, c, out = None):
1268        if out is None:
1269            D,P = x.data.shape[:2]
1270            xbar = x.zeros_like()
1271
1272        else:
1273            xbar, = out
1274
1275        cls._pb_sincos(sbar.data, cbar.data, x.data, s.data, c.data, out = xbar.data)
1276
1277        return out
1278
1279    def arcsin(self):
1280        """ computes y = arcsin(x) in UTP arithmetic"""
1281        rety = self.clone()
1282        retz = self.clone()
1283        self._arcsin(self.data, out = (rety.data, retz.data))
1284        return rety
1285
1286    def arccos(self):
1287        """ computes y = arccos(x) in UTP arithmetic"""
1288        rety = self.clone()
1289        retz = self.clone()
1290        self._arccos(self.data, out = (rety.data, retz.data))
1291        return rety
1292
1293    def arctan(self):
1294        """ computes y = arctan(x) in UTP arithmetic"""
1295        rety = self.clone()
1296        retz = self.clone()
1297        self._arctan(self.data, out = (rety.data, retz.data))
1298        return rety
1299
1300
1301    def sinhcosh(self):
1302        """ simultaneously computes s = sinh(x) and c = cosh(x) in UTP arithmetic"""
1303        rets = self.clone()
1304        retc = self.clone()
1305        self._sinhcosh(self.data, out = (rets.data, retc.data))
1306        return rets, retc
1307
1308    def sinh(self):
1309        """ computes y = sinh(x) in UTP arithmetic """
1310        retval = self.clone()
1311        tmp = self.clone()
1312        self._sinhcosh(self.data, out = (retval.data, tmp.data))
1313        return retval
1314
1315    def cosh(self):
1316        """ computes y = cosh(x) in UTP arithmetic """
1317        retval = self.clone()
1318        tmp = self.clone()
1319        self._sinhcosh(self.data, out = (tmp.data, retval.data))
1320        return retval
1321
1322    def tanh(self):
1323        """ computes y = tanh(x) in UTP arithmetic """
1324        retval = self.clone()
1325        tmp = self.clone()
1326        self._tanhsech2(self.data, out = (retval.data, tmp.data))
1327        return retval
1328
1329    def sign(self):
1330        """ computes y = sign(x) in UTP arithmetic"""
1331        retval = self.clone()
1332        self._sign(self.data, out = retval.data)
1333        return retval
1334
1335    def abs(self):
1336        """ computes y = sign(x) in UTP arithmetic"""
1337        return self.__abs__()
1338
1339    @classmethod
1340    def pb_sign(cls, ybar, x, y, out=None):
1341        """ computes bar y dy = bar x dx in UTP arithmetic"""
1342        if out is None:
1343            D,P = x.data.shape[:2]
1344            xbar = x.zeros_like()
1345        else:
1346            xbar, = out
1347        cls._pb_sign(ybar.data, x.data, y.data, out = xbar.data)
1348        return out
1349
1350
1351    def __abs__(self):
1352        """ absolute value of polynomials
1353
1354        FIXME: theory tells us to check first coefficient if the zero'th coefficient is zero
1355        """
1356        # check if zero order coeff is smaller than 0
1357        tmp = self.data[0] < 0
1358        retval = self.clone()
1359        retval.data *= (-1)**tmp
1360
1361        return retval
1362
1363    def fabs(self):
1364        return self.__abs__()
1365
1366    def __neg__(self):
1367        return self.__class__.neg(self)
1368
1369    def __lt__(self, other):
1370        if isinstance(other,self.__class__):
1371            return numpy.all(self.data[0,...] < other.data[0,...])
1372        else:
1373            return numpy.all(self.data[0,...] < other)
1374
1375    def __le__(self, other):
1376        if isinstance(other,self.__class__):
1377            return numpy.all(self.data[0,...] <= other.data[0,...])
1378        else:
1379            return numpy.all(self.data[0,...] <= other)
1380
1381    def __gt__(self, other):
1382        if isinstance(other,self.__class__):
1383            return numpy.all(self.data[0,...] > other.data[0,...])
1384        else:
1385            return numpy.all(self.data[0,...] > other)
1386
1387    def __ge__(self, other):
1388        if isinstance(other,self.__class__):
1389            return numpy.all(self.data[0,...] >= other.data[0,...])
1390        else:
1391            return numpy.all(self.data[0,...] >= other)
1392
1393    def __eq__(self, other):
1394        if isinstance(other,self.__class__):
1395            return numpy.all(self.data[0,...] == other.data[0,...])
1396        else:
1397            return numpy.all(self.data[0,...] == other)
1398
1399    @classmethod
1400    def neg(cls, x, out = None):
1401        return -1*x
1402
1403    @classmethod
1404    def add(cls, x, y , out = None):
1405        return x + y
1406
1407    @classmethod
1408    def sub(cls, x, y , out = None):
1409        return x - y
1410
1411    @classmethod
1412    def mul(cls, x, y , out = None):
1413        return x * y
1414
1415    @classmethod
1416    def div(cls, x, y , out = None):
1417        return x / y
1418
1419    @classmethod
1420    def multiply(cls, x, y , out = None):
1421        return x * y
1422
1423    @classmethod
1424    def max(cls, a, axis = None, out = None):
1425        if out is not None:
1426            raise NotImplementedError('should implement that')
1427
1428        if axis is not None:
1429            raise NotImplementedError('should implement that')
1430
1431        a_shp = a.data.shape
1432        out_shp = a_shp[:2]
1433        out = cls(cls.__zeros__(out_shp, dtype = a.data.dtype))
1434        cls._max( a.data, axis = axis, out = out.data)
1435        return out
1436
1437    @classmethod
1438    def argmax(cls, a, axis = None):
1439        if axis is not None:
1440            raise NotImplementedError('should implement that')
1441
1442        return cls._argmax( a.data, axis = axis)
1443
1444    @classmethod
1445    def trace(cls, x):
1446        D,P = x.data.shape[:2]
1447        retval = numpy.zeros((D,P), dtype=x.dtype)
1448        for d in range(D):
1449            for p in range(P):
1450                retval[d,p] = numpy.trace(x.data[d,p,...])
1451        return UTPM(retval)
1452
1453    @classmethod
1454    def det(cls, x):
1455        D,P = x.data.shape[:2]
1456        PIV,L,U = cls.lu2(x)
1457        return cls.piv2det(PIV) * cls.prod(cls.diag(U))
1458
1459    @classmethod
1460    def pb_det(cls, ybar, x, y, out = None):
1461        if out is None:
1462            xbar = x.zeros_like()
1463        else:
1464            xbar ,= out
1465
1466        PIV,L,U = cls.lu2(x)
1467        d   = cls.diag(U)
1468        z   = cls.prod(d)
1469        y   = cls.piv2det(PIV) * z
1470
1471        zbar   = cls.piv2det(PIV) * ybar
1472        dbar   = cls.pb_prod(zbar, d, z)
1473        PIVbar = PIV.zeros_like()
1474        Lbar   = L.zeros_like()
1475        Ubar   = cls.pb_diag(dbar, U, d)
1476        cls.pb_lu2(PIVbar, Lbar, Ubar, x, PIV, L, U, out=(xbar,))
1477        return xbar
1478
1479    @classmethod
1480    def logdet(cls, x):
1481        """
1482        compute logdet using algopy.lu2 as described in
1483        http://www.mathworks.com/matlabcentral/fileexchange/22026-safe-computation-of-logarithm-determinat-of-large-matrix/content/logdet.m
1484        """
1485        D,P,N = x.data.shape[:3]
1486        PIV,L,U = cls.lu2(x)
1487        du = cls.diag(U)
1488        su = cls.sign(du)
1489        au = cls.abs(du)
1490        c = cls.piv2det(PIV) * cls.prod(su)
1491        return cls.log(c) + cls.sum(cls.log(au))
1492
1493
1494    @classmethod
    def pb_logdet(cls, ybar, x, y, out = None):
        """Pull back ybar through y = logdet(x).

        Repeats the forward steps of ``logdet`` to obtain the intermediates,
        then reverses them. ``out`` may be a 1-tuple ``(xbar,)`` that is passed
        on to ``pb_lu2``; otherwise a zero UTPM shaped like x is used.
        Returns xbar.
        """
        if out is None:
            xbar = x.zeros_like()
        else:
            xbar ,= out

        # forward recomputation (same steps as logdet); the ``y`` parameter
        # is rebound here and only passed on to pb_sum below
        PIV,L,U = cls.lu2(x)
        du = cls.diag(U)
        su = cls.sign(du)
        au = cls.abs(du)
        c  = cls.piv2det(PIV) * cls.prod(su)
        l  = cls.log(au)
        y  = cls.log(c) + cls.sum(l)

        # reverse sweep; no adjoint is propagated through c (it is built from
        # piv2det and sign terms only)
        lbar    = cls.pb_sum(ybar, l, y, None, None, None)
        aubar   = cls.pb_log(lbar, au, l)
        # d|du|/d du = sign(du)
        dubar   = su * aubar
        PIVbar  = PIV.zeros_like()
        Lbar    = L.zeros_like()
        Ubar    = cls.pb_diag(dubar, U, du)
        cls.pb_lu2(PIVbar, Lbar, Ubar, x, PIV, L, U, out=(xbar,))
        return xbar
1517
1518    def FtoJT(self):
1519        """
1520        Combines several directional derivatives and combines them to a transposed Jacobian JT, i.e.
1521        x.data.shape = (D,P,shp)
1522        y = x.FtoJT()
1523        y.data.shape = (D-1, (P,1) + shp)
1524        """
1525        D,P = self.data.shape[:2]
1526        shp = self.data.shape[2:]
1527        return UTPM(self.data[1:,...].reshape((D-1,1) + (P,) + shp))
1528
1529    def JTtoF(self):
1530        """
1531        inverse operation of FtoJT
1532        x.data.shape = (D,1, P,shp)
1533        y = x.JTtoF()
1534        y.data.shape = (D+1, P, shp)
1535        """
1536        D = self.data.shape[0]
1537        P = self.data.shape[2]
1538        shp = self.data.shape[3:]
1539        tmp = numpy.zeros((D+1,P) + shp)
1540        tmp[0:D,...] = self.data.reshape((D,P) + shp)
1541        return UTPM(tmp)
1542
1543    def clone(self):
1544        """
1545        Returns a new UTPM instance with the same data.
1546
1547        `clone` is opposed to `copy` or `deepcopy` by calling the __init__ function.
1548
1549        Rationale:
1550            the __init__ function may have side effects that must be executed.
1551            Naming stems from the fact that a cloned animal is not an exact copy
1552            but built using the same information.
1553        """
1554        return UTPM(self.data.copy())
1555
1556    def copy(self):
1557        """ this method is equivalent to `clone`.
1558        It's there to allow generic programming because ndarrays do not have the clone method."""
1559        return self.clone()
1560
1561    def get_shape(self):
1562        return numpy.shape(self.data[0,0,...])
1563    shape = property(get_shape)
1564
1565
1566    def get_size(self):
1567        return numpy.size(self.data[0,0,...])
1568    size = property(get_size)
1569
1570    def get_ndim(self):
1571        return numpy.ndim(self.data[0,0,...])
1572    ndim = property(get_ndim)
1573
1574    def __len__(self):
1575        return self.shape[0]
1576
1577    def reshape(self, dims):
1578        return UTPM(self.data.reshape(self.data.shape[0:2] + dims))
1579
1580    def get_transpose(self):
1581        return self.transpose()
1582    def set_transpose(self,x):
1583        raise NotImplementedError('???')
1584    T = property(get_transpose, set_transpose)
1585
1586    def transpose(self, axes = None):
1587        return UTPM( UTPM._transpose(self.data))
1588
1589    def transpose(self, axes = None):
1590        return UTPM( UTPM._transpose(self.data))
1591    def transpose(self, axes=None):
1592        return UTPM(UTPM._transpose(self.data, axes=axes))
1593
1594    def conj(self):
1595        return self.conjugate()
1596
1597    def conjugate(self):
1598        return UTPM(numpy.conjugate(self.data))
1599
1600    def get_owndata(self):
1601        return self.data.flags['OWNDATA']
1602
1603    owndata = property(get_owndata)
1604
1605    def set_zero(self):
1606        self.data[...] = 0.
1607        return self
1608
1609    @classmethod
1610    def zeros(cls, shape, dtype=None):
1611        if not isinstance(dtype, self.__class__):
1612            raise NotImplementedError('dtype must be a UTPM object')
1613        D,P = dtype.data.shape[:2]
1614
1615        if isinstance(shape, int):
1616            shape = (shape,)
1617
1618        return self.__class__(numpy.zeros((D,P) + shape))
1619
1620    def zeros_like(self):
1621        return self.__class__(numpy.zeros_like(self.data))
1622
1623    def ones_like(self):
1624        data = numpy.zeros_like(self.data)
1625        data[0,...] = 1.
1626        return self.__class__(data)
1627
1628    def shift(self, s, out = None):
1629        """
1630        shifting coefficients [x0,x1,x2,x3] s positions
1631
1632        e.g. shift([x0,x1,x2,x3], -1) = [x1,x2,x3,0]
1633             shift([x0,x1,x2,x3], +1) = [0,x0,x1,x2]
1634        """
1635
1636        if out is None:
1637            out = self.zeros_like()
1638
1639        if s <= 0:
1640            out.data[:s,...] = self.data[-s:,...]
1641
1642        else:
1643            out.data[s:,...] = self.data[:-s,...]
1644
1645        return out
1646
1647
1648    def __str__(self):
1649        return str(self.data)
1650
1651    def __repr__(self):
1652        return 'UTPM(' + self.__str__() + ')'
1653
1654    @classmethod
1655    def pb_zeros(cls, *args, **kwargs):
1656        pass
1657
1658    @classmethod
1659    def tril(cls, x, k=0, out = None):
1660        out = x.zeros_like()
1661        D,P = out.data.shape[:2]
1662        # print D,P
1663        for d in range(D):
1664            for p in range(P):
1665                out.data[d,p] = numpy.tril(x.data[d,p], k=k)
1666
1667        return out
1668
1669    @classmethod
1670    def triu(cls, x, k=0, out = None):
1671        out = x.zeros_like()
1672        D,P = out.data.shape[:2]
1673        # print D,P
1674        for d in range(D):
1675            for p in range(P):
1676                out.data[d,p] = numpy.triu(x.data[d,p], k=k)
1677
1678        return out
1679
1680    @classmethod
1681    def init_jacobian(cls, x, dtype=None):
1682        """ initializes this UTPM instance to compute the Jacobian,
1683
1684        it is possible to force the dtype to a certain dtype,
1685        if no dtype is provided, the dtype is inferred from x
1686        """
1687
1688        # print 'called init_jacobian'
1689        x = numpy.asarray(x)
1690
1691        if dtype is None:
1692            # try to infer the dtype from x
1693            dtype= x.dtype
1694
1695            if dtype==int:
1696                dtype=float
1697
1698
1699        shp = numpy.shape(x)
1700        data = numpy.zeros(numpy.hstack( (2, numpy.size(x)) +  shp), dtype=dtype)
1701        data[0] = x
1702
1703        x0 = x.ravel()[0]
1704        # print 'type(x0)=',type(x0)
1705        if isinstance(x0, cls):
1706            data[1] = data[0].copy()
1707            data[1] *= 0
1708            data[1] += numpy.eye(numpy.size(x))
1709
1710        else:
1711            data[1,:].flat = numpy.eye(numpy.size(x))
1712
1713        return cls(data)
1714
1715
1716
1717    @classmethod
1718    def extract_jacobian(cls, x):
1719        """ extracts the Jacobian from a UTPM instance
1720        if x.ndim == 1 it is equivalent to the gradient
1721        """
1722        retval = x.data[1,...].transpose([i for i in range(1,x.data[1,...].ndim)] + [0])
1723
1724        # print 'x.data.dtype=',x.data.dtype
1725        # print 'x.data=',x.data
1726
1727
1728        x0 = retval.ravel()[0]
1729        if isinstance(x0, cls):
1730            # print 'call as_utpm'
1731            retval = cls.as_utpm(retval)
1732
1733        return retval
1734
1735    @classmethod
1736    def init_jac_vec(cls, x, v, dtype=None):
1737        """ initializes this UTPM instance to compute the Jacobian vector product J v,
1738
1739        it is possible to force the dtype to a certain dtype,
1740        if no dtype is provided, the dtype is inferred from x
1741        """
1742
1743        x = numpy.asarray(x)
1744
1745        if dtype is None:
1746            # try to infer the dtype from x
1747            dtype= x.dtype
1748
1749            if dtype==int:
1750                dtype=float
1751
1752
1753        shp = numpy.shape(x)
1754        data = numpy.zeros(numpy.hstack( [2, 1, shp]), dtype=dtype)
1755        data[0,0] = x
1756        data[1,0] = v
1757        return cls(data)
1758
1759    @classmethod
1760    def extract_jac_vec(cls, x):
1761        """ extracts the Jacobian vector product from a UTPM instance
1762        if x.ndim == 1 it is equivalent to the gradient
1763        """
1764        return x.data[1,...].transpose([i for i in range(1,x.data[1,...].ndim)] + [0])[:,0]
1765
1766
1767    @classmethod
1768    def init_tensor(cls, d, x):
1769        """ initializes this UTPM instance to compute the dth degree derivative tensor,
1770        e.g. d=2 is the Hessian
1771        """
1772
1773        import algopy.exact_interpolation as exint
1774        x = numpy.asarray(x)
1775
1776        if x.ndim != 1:
1777            raise NotImplementedError('non vector inputs are not implemented yet')
1778
1779        N = numpy.size(x)
1780        Gamma, rays = exint.generate_Gamma_and_rays(N,d)
1781
1782        data = numpy.zeros(numpy.hstack([d+1,rays.shape]))
1783        data[0] = x
1784        data[1] = rays
1785        return cls(data)
1786
1787    @classmethod
1788    def extract_tensor(cls, N, y, as_full_matrix = True):
1789        """ extracts the Hessian of shape (N,N) from the UTPM instance y
1790        """
1791
1792        import algopy.exact_interpolation as exint
1793        d = y.data.shape[0]-1
1794        Gamma, rays = exint.generate_Gamma_and_rays(N,d)
1795        tmp = numpy.dot(Gamma,y.data[d])
1796
1797        if as_full_matrix == False:
1798            return tmp
1799
1800        else:
1801            retval = numpy.zeros((N,N))
1802            mi = exint.generate_multi_indices(N,d)
1803            pos = exint.convert_multi_indices_to_pos(mi)
1804
1805            for ni in range(mi.shape[0]):
1806                # print 'ni=',ni, mi[ni], pos[ni], tmp[ni]
1807                for perm in exint.generate_permutations(list(pos[ni])):
1808                    retval[perm[0],perm[1]] = tmp[ni]*numpy.max(mi[ni])
1809
1810            return retval
1811
1812
1813    @classmethod
1814    def init_hessian(cls, x):
1815        """ initializes this UTPM instance to compute the Hessian
1816        """
1817
1818        x = numpy.ravel(x)
1819
1820        # generate directions
1821        N = x.size
1822        M = (N*(N+1))//2
1823        L = (N*(N-1))//2
1824        S = numpy.zeros((N,M), dtype=x.dtype)
1825
1826        s = 0
1827        i = 0
1828        for n in range(1,N+1):
1829            S[-n:,s:s+n] = numpy.eye(n)
1830            S[-n,s:s+n] = numpy.ones(n)
1831            s+=n
1832            i+=1
1833        S = S[::-1].T
1834
1835        data = numpy.zeros(numpy.hstack([3,S.shape]), dtype=x.dtype)
1836        data[0] = x
1837        data[1] = S
1838        return cls(data)
1839
1840    @classmethod
1841    def extract_hessian(cls, N, y, as_full_matrix = True, use_mpmath=False):
1842        """ extracts the Hessian of shape (N,N) from the UTPM instance y
1843        """
1844
1845        if use_mpmath:
1846            import mpmath
1847            mpmath.dps = 50
1848
1849
1850        H = numpy.zeros((N,N),dtype=y.data.dtype)
1851        for n in range(N):
1852            for m in range(n):
1853                a =  sum(range(n+1))
1854                b =  sum(range(m+1))
1855                k =  sum(range(n+2)) - m - 1
1856                #print 'k,a,b=', k,a,b
1857                if n!=m:
1858
1859                    if use_mpmath:
1860                        tmp = (mpmath.mpf(y.data[2,k]) - mpmath.mpf(y.data[2,a]) - mpmath.mpf(y.data[2,b]))
1861                    else:
1862                        tmp = (y.data[2,k] - y.data[2,a] - y.data[2,b])
1863
1864                    H[m,n]= H[n,m]= tmp
1865            a =  sum(range(n+1))
1866            H[n,n] = 2*y.data[2,a]
1867        return H
1868
1869    @classmethod
1870    def init_hess_vec(cls, x, v, dtype=None):
1871        """ initializes this UTPM instance to compute the Hessian vector product H v,
1872
1873        it is possible to force the dtype to a certain dtype,
1874        if no dtype is provided, the dtype is inferred from x
1875        """
1876
1877        x = numpy.asarray(x)
1878        if x.ndim != 1:
1879            raise NotImplementedError(
1880                    'non vector inputs are not implemented yet')
1881
1882        if dtype is None:
1883            # try to infer the dtype from x
1884            dtype= x.dtype
1885
1886            if dtype==int:
1887                dtype=float
1888
1889        N = numpy.size(x)
1890        P = 2*N + 1
1891        ident = numpy.identity(N)
1892
1893        # Construct the UTPM data.
1894        data = numpy.zeros((3, P, N), dtype=dtype)
1895        data[0] = x
1896        for n in range(N):
1897            data[1, n, :] = ident[n]
1898            data[1, n+N, :] = v + ident[n]
1899        data[1, 2*N, :] = v
1900
1901        return cls(data)
1902
1903    @classmethod
1904    def extract_hess_vec(cls, N, x):
1905        """ extracts the Hessian-vector product from a UTPM instance
1906        """
1907        Hv = numpy.zeros(N)
1908        for n in range(N):
1909            Hv[n] = -x.data[2, n] + x.data[2, n+N] - x.data[2, 2*N]
1910        return Hv
1911
1912    @classmethod
    def dot(cls, x, y, out = None):
        """
        out = dot(x,y)

        Product of the user-visible arrays, following numpy.dot's shape
        rules. At least one of x, y must be a UTPM. The other operand only
        needs ``.shape`` and ``.dtype``. The result dtype is the promotion
        of both operand dtypes; anything else raises NotImplementedError.

        NOTE(review): ``out`` is not used as an output buffer. Every branch
        rebinds it to a freshly allocated UTPM.
        """

        if isinstance(x, UTPM) and isinstance(y, UTPM):
            x_shp = x.data.shape
            y_shp = y.data.shape

            # both operands must share the (D, P) Taylor layout
            assert x_shp[:2] == y_shp[:2]

            if  len(y_shp[2:]) == 1:
                # y is a vector: contract the last axis of x with it
                out_shp = x_shp[:-1]

            else:
                # numpy.dot rule: x[..., k] * y[..., k, :]
                out_shp = x_shp[:2] + x_shp[2:-1] + y_shp[2:][:-2] + y_shp[2:][-1:]

            out = cls(cls.__zeros__(out_shp, dtype=numpy.promote_types(x.data.dtype, y.data.dtype)))
            cls._dot( x.data, y.data, out = out.data)

        elif isinstance(x, UTPM) and not isinstance(y, UTPM):
            # y is a constant array without Taylor axes
            x_shp = x.data.shape
            y_shp = y.shape

            if  len(y_shp) == 1:
                out_shp = x_shp[:-1]

            else:
                out_shp = x_shp[:2] + x_shp[2:-1] + y_shp[:-2] + y_shp[-1:]

            out = cls(cls.__zeros__(out_shp, dtype=numpy.promote_types(x.data.dtype, y.dtype)))
            cls._dot_non_UTPM_y(x.data, y, out = out.data)

        elif not isinstance(x, UTPM) and isinstance(y, UTPM):
            # x is a constant array; the (D, P) axes come from y
            x_shp = x.shape
            y_shp = y.data.shape

            if  len(y_shp[2:]) == 1:
                out_shp = y_shp[:2] + x_shp[:-1]

            else:
                out_shp = y_shp[:2] + x_shp[:-1] + y_shp[2:][:-2] + y_shp[2:][-1:]

            out = cls(cls.__zeros__(out_shp, dtype=numpy.promote_types(x.dtype, y.data.dtype)))
            cls._dot_non_UTPM_x(x, y.data, out = out.data)


        else:
            raise NotImplementedError('should implement that')

        return out
1964
1965    @classmethod
1966    def outer(cls, x, y, out = None):
1967        """
1968        out = outer(x,y)
1969        """
1970
1971        if isinstance(x, UTPM) and isinstance(y, UTPM):
1972            x_shp = x.data.shape
1973            y_shp = y.data.shape
1974
1975            assert x_shp[:2] == y_shp[:2]
1976            assert len(y_shp[2:]) == 1
1977
1978            out_shp = x_shp + x_shp[-1:]
1979            out = cls(cls.__zeros__(out_shp, dtype = x.data.dtype))
1980            cls._outer( x.data, y.data, out = out.data)
1981
1982        elif isinstance(x, UTPM) and isinstance(y, numpy.ndarray):
1983            x_shp = x.data.shape
1984            out_shp = x_shp + x_shp[-1:]
1985            out = cls(cls.__zeros__(out_shp, dtype = x.data.dtype))
1986            cls._outer_non_utpm_y( x.data, y, out = out.data)
1987
1988        elif isinstance(x, numpy.ndarray) and isinstance(y, UTPM):
1989            y_shp = y.data.shape
1990            out_shp = y_shp + y_shp[-1:]
1991            out = cls(cls.__zeros__(out_shp, dtype = y.data.dtype))
1992            cls._outer_non_utpm_x( x, y.data, out = out.data)
1993
1994        else:
1995            raise NotImplementedError('this operation is not supported')
1996
1997        return out
1998
1999    @classmethod
2000    def pb_outer(cls, zbar, x, y, z, out = None):
2001        if out is None:
2002            D,P = y.data.shape[:2]
2003            xbar = x.zeros_like()
2004            ybar = y.zeros_like()
2005
2006        else:
2007            xbar, ybar = out
2008
2009        cls._outer_pullback(zbar.data, x.data, y.data, z.data, out = (xbar.data, ybar.data))
2010        return (xbar,ybar)
2011
2012
2013    @classmethod
2014    def inv(cls, A, out = None):
2015        if out is None:
2016            out = cls(cls.__zeros__(A.data.shape, dtype = A.data.dtype))
2017        else:
2018            raise NotImplementedError('')
2019
2020        cls._inv(A.data,(out.data,))
2021        return out
2022        # # tc[0] element
2023        # for p in range(P):
2024        #     out.data[0,p,:,:] = numpy.linalg.inv(A.data[0,p,:,:])
2025
2026        # # tc[d] elements
2027        # for d in range(1,D):
2028        #     for p in range(P):
2029        #         for c in range(1,d+1):
2030        #             out.data[d,p,:,:] += numpy.dot(A.data[c,p,:,:], out.data[d-c,p,:,:],)
2031        #         out.data[d,p,:,:] =  numpy.dot(-out.data[0,p,:,:], out.data[d,p,:,:],)
2032        # return out
2033
2034    @classmethod
2035    def solve(cls, A, x, out = None):
2036        """
2037        solves for y in: A y = x
2038
2039        """
2040        if isinstance(A, UTPM) and isinstance(x, UTPM):
2041            A_shp = A.data.shape
2042            x_shp = x.data.shape
2043
2044            assert A_shp[:2] == x_shp[:2]
2045            if A_shp[2] != x_shp[2]:
2046                print(ValueError('A.data.shape = %s does not match x.data.shape = %s'%(str(A_shp), str(x_shp))))
2047
2048            if len(x_shp) == 3:
2049                raise ValueError("require x.data.shape=(D,P,M,K) and A.data.shape=(D,P,M,N) but provided x.data.shape(D,P,M)=%s"%str(x.data.shape))
2050
2051            D, P, M = A_shp[:3]
2052
2053            if out is None:
2054                dtype = numpy.promote_types(A.data.dtype, x.data.dtype)
2055                out = cls(cls.__zeros__((D,P,M) + x_shp[3:], dtype=dtype))
2056
2057            UTPM._solve(A.data, x.data, out = out.data)
2058
2059        elif not isinstance(A, UTPM) and isinstance(x, UTPM):
2060            A_shp = numpy.shape(A)
2061            x_shp = numpy.shape(x.data)
2062            M = A_shp[0]
2063            D,P = x_shp[:2]
2064            dtype = numpy.promote_types(A.dtype, x.data.dtype)
2065            out = cls(cls.__zeros__((D,P,M) + x_shp[3:], dtype=dtype))
2066            cls._solve_non_UTPM_A(A, x.data, out = out.data)
2067
2068        elif isinstance(A, UTPM) and not isinstance(x, UTPM):
2069            A_shp = numpy.shape(A.data)
2070            x_shp = numpy.shape(x)
2071            D,P,M = A_shp[:3]
2072            dtype = numpy.promote_types(A.data.dtype, x.dtype)
2073            out = cls(cls.__zeros__((D,P,M) + x_shp[1:], dtype=dtype))
2074            cls._solve_non_UTPM_x(A.data, x, out = out.data)
2075
2076        else:
2077            raise NotImplementedError('should implement that')
2078
2079        return out
2080
2081    @classmethod
2082    def cholesky(cls, A, out = None):
2083        if out is None:
2084            out = A.zeros_like()
2085
2086        cls._cholesky(A.data, out.data)
2087        return out
2088
2089    @classmethod
2090    def pb_cholesky(cls, Lbar, A, L, out = None):
2091        if out is None:
2092            D,P = A.data.shape[:2]
2093            Abar = A.zeros_like()
2094
2095        else:
2096            Abar, = out
2097
2098        cls._pb_cholesky(Lbar.data, A.data, L.data, out = Abar.data)
2099        return Abar
2100
2101    @classmethod
2102    def lu_factor(cls, A, out = None):
2103        """
2104        univariate Taylor arithmetic of scipy.linalg.lu_factor
2105        """
2106        D,P,N = A.data.shape[:3]
2107
2108        if out is None:
2109            LU  = A.zeros_like()
2110            PIV = cls(numpy.zeros((D,P,N))) # permutation
2111
2112        for p in range(P):
2113            # D = 0
2114            lu, piv = scipy.linalg.lu_factor(A.data[0,p])
2115            w = algopy.utils.piv2mat(piv)
2116
2117            LU.data[0,p] = lu
2118            PIV.data[0,p] = piv
2119
2120            L0 = numpy.tril(lu, -1) + numpy.eye(N)
2121            U0 = numpy.triu(lu, 0)
2122
2123            # allocate temporary storage
2124            L0inv = numpy.linalg.inv(L0)
2125            U0inv = numpy.linalg.inv(U0)
2126            dF    = numpy.zeros((N,N),dtype=float)
2127
2128            for d in range(1,D):
2129                dF *= 0
2130                for i in range(1,d):
2131                    tmp1 = numpy.tril(LU.data[d-i,p], -1)
2132                    tmp2 = numpy.triu(LU.data[i,p], 0)
2133                    dF -= numpy.dot(tmp1, tmp2)
2134                dF += numpy.dot(w.T, A.data[d,p])
2135                dF = numpy.dot(L0inv, numpy.dot(dF, U0inv))
2136
2137                Ud = numpy.dot(numpy.triu(dF, 0), U0)
2138                Ld = numpy.dot(L0, numpy.tril(dF, -1))
2139
2140                LU.data[d, p] += Ud
2141                LU.data[d, p] += Ld
2142
2143        return LU, PIV
2144
2145    @classmethod
2146    def lu(cls, A, out = None):
2147        """
2148        univariate Taylor arithmetic of scipy.linalg.lu
2149        """
2150        D,P,N = A.data.shape[:3]
2151
2152        if out is None:
2153            L = A.zeros_like()
2154            U = A.zeros_like()
2155            W = A.zeros_like() # permutation matrix
2156
2157
2158        for p in range(P):
2159            # D = 0
2160            w,l,u = scipy.linalg.lu(A.data[0,p])
2161            W.data[0,p] = w
2162            L.data[0,p] = l
2163            U.data[0,p] = u
2164
2165            # allocate temporary storage
2166            L0inv = numpy.linalg.inv(L.data[0,p])
2167            U0inv = numpy.linalg.inv(U.data[0,p])
2168            dF    = numpy.zeros((N,N),dtype=float)
2169
2170            for d in range(1,D):
2171                dF *= 0
2172                for i in range(1,d):
2173                    dF -= numpy.dot(L.data[d-i,p], U.data[i,p])
2174                dF += numpy.dot(w.T, A.data[d,p])
2175                dF = numpy.dot(L0inv, numpy.dot(dF, U0inv))
2176
2177                U.data[d,p] = numpy.dot(numpy.triu(dF, 0), U.data[0,p])
2178                L.data[d,p] = numpy.dot(L.data[0,p], numpy.tril(dF, -1))
2179
2180        return W, L, U
2181
2182    @classmethod
2183    def pb_lu(cls, Wbar, Lbar, Ubar, A, W, L, U, out = None):
2184        D,P,M,N = numpy.shape(A.data)
2185
2186        if out is None:
2187            Abar = A.zeros_like()
2188
2189        else:
2190            Abar, = out
2191
2192        v1 = cls.tril(cls.dot(L.T, Lbar), -1) + cls.triu(cls.dot(Ubar, U.T), 0)
2193        v2 = cls.solve(L.T, v1)
2194        v3 = cls.solve(U, v2.T).T
2195
2196        Abar += cls.dot(W, v3)
2197
2198        return Abar
2199
2200    @classmethod
2201    def lu2(cls, A, out = None):
2202        """
2203        univariate Taylor arithmetic of scipy.linalg.lu_factor
2204        but returns piv, L, U = lu2(A)
2205        """
2206        D,P,N = A.data.shape[:3]
2207
2208        if out is None:
2209            PIV = cls(numpy.zeros((D,P,N), dtype=int)) # pivot elements
2210            L = A.zeros_like()
2211            U = A.zeros_like()
2212
2213        for p in range(P):
2214            # D = 0
2215            lu, piv = scipy.linalg.lu_factor(A.data[0,p])
2216            w = algopy.utils.piv2mat(piv)
2217            L.data[0,p] = numpy.tril(lu, -1) + numpy.eye(N)
2218            U.data[0,p] = numpy.triu(lu, 0)
2219            PIV.data[0,p] = piv
2220
2221            # allocate temporary storage
2222            L0inv = numpy.linalg.inv(L.data[0,p])
2223            U0inv = numpy.linalg.inv(U.data[0,p])
2224            dF    = numpy.zeros((N,N),dtype=float)
2225
2226            for d in range(1,D):
2227                dF *= 0
2228                for i in range(1,d):
2229                    dF -= numpy.dot(L.data[d-i,p], U.data[i,p])
2230                dF += numpy.dot(w.T, A.data[d,p])
2231                dF = numpy.dot(L0inv, numpy.dot(dF, U0inv))
2232
2233                U.data[d,p] = numpy.dot(numpy.triu(dF, 0), U.data[0,p])
2234                L.data[d,p] = numpy.dot(L.data[0,p], numpy.tril(dF, -1))
2235
2236        return PIV, L, U
2237
2238    @classmethod
2239    def pb_lu2(cls, PIVbar, Lbar, Ubar, A, PIV, L, U, out = None):
2240        D,P,M,N = numpy.shape(A.data)
2241
2242        if out is None:
2243            Abar = A.zeros_like()
2244
2245        else:
2246            Abar, = out
2247
2248        v1 = cls.tril(cls.dot(L.T, Lbar), -1) + cls.triu(cls.dot(Ubar, U.T), 0)
2249        v2 = cls.solve(L.T, v1)
2250        v3 = cls.solve(U, v2.T).T
2251
2252        W = cls.piv2mat(PIV)
2253        Abar += cls.dot(W, v3)
2254
2255        return Abar
2256
2257
2258    @classmethod
2259    def piv2mat(cls, piv):
2260        D,P,N = piv.data.shape
2261        W = cls(numpy.zeros((D,P,N,N)))
2262        for p in range(P):
2263            W.data[0,p] = algopy.utils.piv2mat(piv.data[0,p])
2264
2265        return W
2266
2267    @classmethod
2268    def piv2det(cls, piv):
2269        D,P,N = piv.data.shape
2270        det = cls(numpy.zeros((D,P)))
2271        for p in range(P):
2272            det.data[0,p] = algopy.utils.piv2det(piv.data[0,p])
2273        return det
2274
2275    @classmethod
2276    def pb_Id(cls, ybar, x, y, out = None):
2277        return out
2278
2279    @classmethod
2280    def pb_neg(cls, ybar, x, y, out = None):
2281        if out is None:
2282            xbar = x.zeros_like()
2283
2284        else:
2285            xbar, = out
2286
2287        xbar -= ybar
2288        return xbar
2289
2290    @classmethod
2291    def pb___neg__(cls, ybar, x, y, out = None):
2292        return cls.pb_neg(ybar, x, y, out = out)
2293
2294    @classmethod
2295    def pb___add__(cls, zbar, x, y , z, out = None):
2296        return cls.pb_add(zbar, x, y , z, out = out)
2297
2298    @classmethod
2299    def pb___sub__(cls, zbar, x, y , z, out = None):
2300        return cls.pb_sub(zbar, x, y , z, out = out)
2301
2302    @classmethod
2303    def pb___mul__(cls, zbar, x, y , z, out = None):
2304        return cls.pb_mul(zbar, x, y , z, out = out)
2305
2306    @classmethod
2307    def pb___truediv__(cls, zbar, x, y , z, out = None):
2308        return cls.pb_truediv(zbar, x, y , z, out = out)
2309
2310    @classmethod
2311    def pb_add(cls, zbar, x, y, z, out=None):
2312        if out is None:
2313            D, P = y.data.shape[:2]
2314            xbar = x.zeros_like()
2315            ybar = y.zeros_like()
2316
2317        else:
2318            xbar, ybar = out
2319
2320        if isinstance(xbar, UTPM):
2321            xbar2, zbar2 = cls.broadcast(xbar, zbar)
2322
2323            # print 'xbar = ', xbar
2324            # print 'zbar = ', zbar
2325            # print 'xbar2 = ', xbar2
2326            # print 'zbar2 = ', zbar2
2327
2328            # print 'xbar2.data.strides = ', xbar2.data.strides
2329            # print 'zbar2.data.strides = ', zbar2.data.strides
2330            # print 'xbar2 + zbar2 = ', xbar2 + zbar2
2331
2332            workaround_strides_function(xbar2, zbar2, operator.iadd)
2333            # xbar2[...] = xbar2 + zbar2
2334
2335            # print 'after update'
2336            # print 'xbar2 =\n', xbar2
2337            # print 'xbar =\n', xbar
2338
2339        if isinstance(ybar, UTPM):
2340            ybar2, zbar2 = cls.broadcast(ybar, zbar)
2341            workaround_strides_function(ybar2, zbar2, operator.iadd)
2342            # ybar2 += zbar2
2343        # print 'ybar2.data.shape=',ybar2.data.shape
2344
2345
2346        return (xbar, ybar)
2347
2348
2349    @classmethod
2350    def pb___iadd__(cls, zbar, x, y, z, out = None):
2351        # FIXME: this is a workaround/hack, review this
2352        x = x.copy()
2353        return cls.pb___add__(zbar, x, y, z, out = out)
2354        # if out is None:
2355            # D,P = y.data.shape[:2]
2356            # xbar = cls(cls.__zeros__(x.data.shape))
2357            # ybar = cls(cls.__zeros__(y.data.shape))
2358
2359        # else:
2360            # xbar, ybar = out
2361
2362        # xbar = zbar
2363        # ybar += zbar
2364
2365
2366        # return xbar, ybar
2367
2368    @classmethod
2369    def pb_sub(cls, zbar, x, y , z, out = None):
2370        if out is None:
2371            D,P = y.data.shape[:2]
2372            xbar = x.zeros_like()
2373            ybar = y.zeros_like()
2374
2375        else:
2376            xbar, ybar = out
2377
2378        if isinstance(x, UTPM):
2379            xbar2,zbar2 = cls.broadcast(xbar, zbar)
2380            workaround_strides_function(xbar2, zbar2, operator.iadd)
2381            # xbar2 += zbar2
2382
2383        if isinstance(y, UTPM):
2384            ybar2,zbar2 = cls.broadcast(ybar, zbar)
2385            workaround_strides_function(ybar2, zbar2, operator.isub)
2386            # ybar2 -= zbar2
2387
2388        return (xbar,ybar)
2389
2390
2391    @classmethod
2392    def pb_mul(cls, zbar, x, y , z, out = None):
2393
2394        if isinstance(x, UTPM) and isinstance(y, UTPM):
2395            if out is None:
2396                D,P = z.data.shape[:2]
2397                xbar = x.zeros_like()
2398                ybar = y.zeros_like()
2399
2400            else:
2401                xbar, ybar = out
2402
2403            xbar2, tmp = cls.broadcast(xbar, zbar)
2404            ybar2, tmp = cls.broadcast(ybar, zbar)
2405
2406            # xbar2 += zbar * y
2407            workaround_strides_function(xbar2, zbar * y, operator.iadd)
2408            # ybar2 += zbar * x
2409            workaround_strides_function(ybar2, zbar * x, operator.iadd)
2410
2411            return (xbar, ybar)
2412
2413        elif isinstance(x, UTPM):
2414            if out is None:
2415                D, P = z.data.shape[:2]
2416                xbar = x.zeros_like()
2417                ybar = None
2418
2419            else:
2420                xbar, ybar = out
2421
2422            xbar2, tmp = cls.broadcast(xbar, zbar)
2423
2424            workaround_strides_function(xbar2, zbar * y, operator.iadd)
2425            # xbar2 += zbar * y
2426
2427            return (xbar, ybar)
2428
2429        elif isinstance(y, UTPM):
2430            if out is None:
2431                D, P = z.data.shape[:2]
2432                xbar = None
2433                ybar = y.zeros_like()
2434
2435            else:
2436                xbar, ybar = out
2437
2438            ybar2, tmp = cls.broadcast(ybar, zbar)
2439
2440            workaround_strides_function(xbar2, zbar * x, operator.iadd)
2441            # ybar2 += zbar * x
2442
2443            return (xbar, ybar)
2444
2445        else:
2446            raise NotImplementedError('not implemented')
2447
2448    @classmethod
2449    def pb_truediv(cls, zbar, x, y, z, out=None):
2450
2451        if isinstance(x, UTPM) and isinstance(y, UTPM):
2452
2453            if out is None:
2454                D,P = y.data.shape[:2]
2455                xbar = x.zeros_like()
2456                ybar = y.zeros_like()
2457
2458            else:
2459                xbar, ybar = out
2460
2461            x2, y2 = cls.broadcast(x, y)
2462
2463            xbar2, tmp = cls.broadcast(xbar, zbar)
2464            ybar2, tmp = cls.broadcast(ybar, zbar)
2465
2466            tmp = zbar.clone()
2467            # tmp /= y2
2468            workaround_strides_function(tmp, y2, operator.itruediv)
2469            # xbar2 += tmp
2470            workaround_strides_function(xbar2, tmp, operator.iadd)
2471            # tmp *= z
2472            workaround_strides_function(tmp, z, operator.imul)
2473            # ybar2 -= tmp
2474            workaround_strides_function(ybar2, tmp, operator.isub)
2475
2476            return (xbar, ybar)
2477
2478        elif isinstance(x, UTPM):
2479
2480            if out is None:
2481                D, P = z.data.shape[:2]
2482                xbar = x.zeros_like()
2483                ybar = None
2484
2485            else:
2486                xbar, ybar = out
2487
2488            xbar2, tmp = cls.broadcast(xbar, zbar)
2489
2490            # tmp /= y2
2491            # xbar2 += tmp
2492            workaround_strides_function(xbar2, zbar / y, operator.iadd)
2493
2494            return (xbar, ybar)
2495
2496        elif isinstance(y, UTPM):
2497
2498            if out is None:
2499                D, P = z.data.shape[:2]
2500                xbar = None
2501                ybar = y.zeros_like()
2502
2503            else:
2504                xbar, ybar = out
2505
2506            ybar2, tmp = cls.broadcast(ybar, zbar)
2507            workaround_strides_function(ybar2, zbar / y * z, operator.isub)
2508
2509            return (xbar, ybar)
2510
2511    @classmethod
2512    def broadcast(cls, x,y):
2513        """
2514        this is the UTPM equivalent to numpy.broadcast_arrays
2515        """
2516        if numpy.isscalar(x) or isinstance(x,numpy.ndarray):
2517            return x,y
2518
2519        if numpy.isscalar(y) or isinstance(y,numpy.ndarray):
2520            return x,y
2521
2522        # broadcast xbar and ybar
2523        x2_data, y2_data = cls._broadcast_arrays(x.data,y.data)
2524
2525        x2 = UTPM(x2_data)
2526        y2 = UTPM(y2_data)
2527        return x2, y2
2528
2529    @classmethod
2530    def pb_dot(cls, zbar, x, y, z, out = None):
2531        if out is None:
2532            D,P = y.data.shape[:2]
2533            xbar = x.zeros_like()
2534            ybar = y.zeros_like()
2535
2536        else:
2537            xbar, ybar = out
2538
2539        # print 'x = ', type(x)
2540        # print 'y = ',type(y)
2541        # print 'z = ',type(z)
2542
2543        # print 'xbar = ', type(xbar)
2544        # print 'ybar = ',type(ybar)
2545        # print 'zbar = ',type(zbar)
2546
2547        if not isinstance(x,cls):
2548            D,P = z.data.shape[:2]
2549            tmp = cls(numpy.zeros((D,P) + x.shape,dtype=z.data.dtype))
2550            tmp[...] = x[...]
2551            x = tmp
2552
2553        if not isinstance(xbar,cls):
2554            xbar = cls(numpy.zeros((D,P) + x.shape,dtype=z.data.dtype))
2555
2556        if not isinstance(y,cls):
2557            D,P = xbar.data.shape[:2]
2558            tmp = cls(numpy.zeros((D,P) + y.shape,dtype=z.data.dtype))
2559            tmp[...] = y[...]
2560            y = tmp
2561
2562        if not isinstance(ybar,cls):
2563            ybar = cls(numpy.zeros((D,P) + y.shape,dtype=z.data.dtype))
2564
2565        cls._dot_pullback(zbar.data, x.data, y.data, z.data, out = (xbar.data, ybar.data))
2566        return (xbar,ybar)
2567
2568    @classmethod
2569    def pb_reshape(cls, ybar, x, newshape, y, out = None):
2570        if out is None:
2571            D,P = y.data.shape[:2]
2572            xbar = x.zeros_like()
2573
2574        else:
2575            xbar = out[0]
2576
2577        cls._pb_reshape(ybar.data, x.data, y.data, out = xbar.data)
2578        return xbar
2579
2580
2581    @classmethod
2582    def pb_inv(cls, ybar, x, y, out = None):
2583        if out is None:
2584            D,P = y.data.shape[:2]
2585            xbar = x.zeros_like()
2586
2587        else:
2588            xbar, = out
2589
2590        cls._inv_pullback(ybar.data, x.data, y.data, out = xbar.data)
2591        return xbar
2592
2593
2594    @classmethod
2595    def pb_solve(cls, ybar, A, x, y, out = None):
2596        D,P = y.data.shape[:2]
2597
2598
2599        if not isinstance(A, UTPM):
2600            raise NotImplementedError('should implement that')
2601
2602        if not isinstance(x, UTPM):
2603
2604            tmp = x
2605            x = UTPM(numpy.zeros( (D,P) + x.shape))
2606            for p in range(P):
2607                x.data[0,p] = tmp[...]
2608
2609        if out is None:
2610            xbar = x.zeros_like()
2611            Abar = A.zeros_like()
2612
2613        else:
2614            if out[1] is None:
2615                Abar = out[0]
2616                xbar = x.zeros_like()
2617
2618            else:
2619                Abar, xbar = out
2620
2621        cls._solve_pullback(ybar.data, A.data, x.data, y.data, out = (Abar.data, xbar.data))
2622
2623        return Abar, xbar
2624
2625    @classmethod
2626    def pb_trace(cls, ybar, x, y, out = None):
2627        if out is None:
2628            out = (x.zeros_like(),)
2629
2630        xbar, = out
2631        Nx = xbar.shape[0]
2632        for nx in range(Nx):
2633            xbar[nx,nx] += ybar
2634
2635        return xbar
2636
2637    @classmethod
2638    def pb_transpose(cls, ybar, x, y, out = None):
2639        if out is None:
2640            raise NotImplementedError('should implement that')
2641
2642        xbar, = out
2643        xbar = cls.transpose(ybar)
2644        return xbar
2645
2646    @classmethod
2647    def pb_conjugate(cls, ybar, x, y, out = None):
2648        if out is None:
2649            raise NotImplementedError('should implement that')
2650
2651        xbar, = out
2652        xbar += cls.conjugate(ybar)
2653        return xbar
2654
2655
2656
2657    @classmethod
2658    def qr(cls, A, out = None, work = None, epsilon = 1e-14):
2659        D,P,M,N = numpy.shape(A.data)
2660        K = min(M,N)
2661
2662        if out is None:
2663            Q = cls(cls.__zeros__((D,P,M,K), dtype=A.data.dtype))
2664            R = cls(cls.__zeros__((D,P,K,N), dtype=A.data.dtype))
2665
2666        else:
2667            Q,R = out
2668
2669        UTPM._qr(A.data, out = (Q.data, R.data), epsilon = epsilon)
2670        return Q,R
2671
2672    @classmethod
2673    def pb_qr(cls, Qbar, Rbar, A, Q, R, out = None):
2674        D,P,M,N = numpy.shape(A.data)
2675
2676        if out is None:
2677            Abar = A.zeros_like()
2678
2679        else:
2680            Abar, = out
2681
2682        UTPM._qr_pullback( Qbar.data, Rbar.data, A.data, Q.data, R.data, out = Abar.data)
2683        return Abar
2684
2685    @classmethod
2686    def qr_full(cls, A, out = None, work = None):
2687        D,P,M,N = numpy.shape(A.data)
2688
2689        if out is None:
2690            Q = cls(cls.__zeros__((D,P,M,M), dtype=A.data.dtype))
2691            R = cls(cls.__zeros__((D,P,M,N), dtype=A.data.dtype))
2692
2693        else:
2694            Q,R = out
2695
2696        UTPM._qr_full(A.data, out = (Q.data, R.data))
2697
2698        return Q,R
2699
2700    @classmethod
2701    def pb_qr_full(cls, Qbar, Rbar, A, Q, R, out = None):
2702        D,P,M,N = numpy.shape(A.data)
2703
2704        if out is None:
2705            Abar = A.zeros_like()
2706
2707        else:
2708            Abar, = out
2709
2710        UTPM._qr_full_pullback( Qbar.data, Rbar.data, A.data, Q.data, R.data, out = Abar.data)
2711        return Abar
2712
2713
2714    @classmethod
2715    def eigh(cls, A, out = None, epsilon = 1e-8):
2716        """
2717        computes the eigenvalue decomposition A = Q^T L Q
2718        of a symmetrical matrix A with distinct eigenvalues
2719
2720        (l,Q) = UTPM.eigh(A, out=None)
2721
2722        """
2723
2724        D,P,M,N = numpy.shape(A.data)
2725
2726        if out is None:
2727            l = cls(cls.__zeros__((D,P,N), dtype=A.data.dtype))
2728            Q = cls(cls.__zeros__((D,P,N,N), dtype=A.data.dtype))
2729
2730        else:
2731            l,Q = out
2732
2733        UTPM._eigh( l.data, Q.data, A.data, epsilon = epsilon)
2734
2735        return l,Q
2736
2737    @classmethod
2738    def eigh1(cls, A, out = None, epsilon = 1e-8):
2739        """
2740        computes the relaxed eigenvalue decompositin of level 1
2741        of a symmetrical matrix A with distinct eigenvalues
2742
2743        (L,Q,b) = UTPM.eig1(A)
2744
2745        """
2746
2747        D,P,M,N = numpy.shape(A.data)
2748
2749        if out is None:
2750            L = cls(cls.__zeros__((D,P,N,N), dtype=A.data.dtype))
2751            Q = cls(cls.__zeros__((D,P,N,N), dtype=A.data.dtype))
2752
2753        else:
2754            L,Q = out
2755
2756        b_list = []
2757        for p in range(P):
2758            b = UTPM._eigh1( L.data[:,p], Q.data[:,p], A.data[:,p], epsilon = epsilon)
2759            b_list.append(b)
2760
2761        return L,Q,b_list
2762
2763    @classmethod
2764    def pb_eigh(cls, lbar, Qbar,  A, l, Q,  out = None):
2765        D,P,M,N = numpy.shape(A.data)
2766
2767        if out is None:
2768            Abar = A.zeros_like()
2769
2770        else:
2771            Abar, = out
2772
2773        UTPM._eigh_pullback( lbar.data,  Qbar.data, A.data,  l.data, Q.data, out = Abar.data)
2774        return Abar
2775
2776    @classmethod
2777    def pb_eigh1(cls, Lbar, Qbar, bbar_list, A, L, Q, b_list,  out = None):
2778        D,P,M,N = numpy.shape(A.data)
2779
2780        if out is None:
2781            Abar = A.zeros_like()
2782
2783        else:
2784            Abar, = out
2785
2786        UTPM._eigh1_pullback( Lbar.data,  Qbar.data, A.data,  L.data, Q.data, b_list, out = Abar.data)
2787        return Abar
2788
2789
2790    @classmethod
2791    def eig(cls, A, out = None):
2792        """
2793        computes the eigenvalue decomposition Q^-1 A Q = L
2794        of a diagonalizable matrix A with distinct eigenvalues
2795
2796        (l,Q) = UTPM.eig(A, out=None)
2797
2798        """
2799
2800        D,P,M,N = numpy.shape(A.data)
2801
2802        assert M == N, 'A must be a square matrix, but A.shape = (%d, %d)!'%A.shape
2803
2804        assert D <= 2, 'sorry: only first-order Taylor polynomials are supported right now'
2805
2806        if out is None:
2807            l = cls(cls.__zeros__((D,P,N), dtype='complex'))
2808            Q = cls(cls.__zeros__((D,P,N,N), dtype='complex'))
2809
2810        else:
2811            l,Q = out
2812
2813        for p in range(P):
2814
2815            # d=0: nominal computation
2816            t1, t2 = numpy.linalg.eig(A.data[0,p])
2817
2818            l.data[0,p], Q.data[0,p] = t1, t2
2819
2820            if D == 2:
2821                # d=1: first-order coefficient
2822                v1 = numpy.linalg.solve(Q.data[0,p], A.data[1,p])
2823                v2 = numpy.dot(v1, Q.data[0,p])
2824                l.data[1,p] = numpy.diag(v2)
2825
2826                F = numpy.zeros((M,M), dtype=l.data.dtype)
2827                for i in range(M):
2828                    F[i, :] -= l.data[0,p,i]
2829
2830                for j in range(M):
2831                    F[:, j] += l.data[0,p,j]
2832                    F[j, j] = numpy.infty
2833
2834                F = 1./F
2835
2836                Q.data[1,p] = numpy.dot(Q.data[0,p], F * v2)
2837
2838        if numpy.allclose(0, l.data.imag) and numpy.allclose(0, Q.data.imag):
2839            l = cls(l.data.real.astype(float))
2840            Q = cls(Q.data.real.astype(float))
2841
2842        return l, Q
2843
2844    @classmethod
2845    def pb_eig(cls, lbar, Qbar,  A, l, Q,  out = None):
2846        D,P,M,N = numpy.shape(A.data)
2847
2848        if out is None:
2849            Abar = A.zeros_like()
2850
2851        else:
2852            Abar, = out
2853
2854        E = Q.zeros_like()
2855        for i in range(M):
2856            E[i, :] -= l[i]
2857
2858        for j in range(M):
2859            E[:, j] += l[j]
2860            E[j, j] = numpy.infty
2861
2862        F = 1./E
2863        Lbar = UTPM.diag(lbar)
2864        v1 = Lbar + F * UTPM.dot(Q.T, Qbar)
2865        v2 = UTPM.dot(v1, Q.T)
2866        v3 = UTPM.solve(Q.T, v2)
2867        Abar += v3
2868
2869        return Abar
2870
2871
2872    @classmethod
2873    def svd(cls, A, out = None, epsilon = 1e-8):
2874        """
2875        computes the singular value decomposition A = U S V.T
2876        of matrices A with full rank (i.e. nonzero singular values)
2877        by reformulation to eigh.
2878
2879        (U, S, VT) = UTPM.svd(A, epsilon= 1e-8)
2880
2881        Parameters
2882        ----------
2883
2884        A: array_like
2885            input array (numpy.ndarray, algopy.UTPM or algopy.Function instance)
2886
2887        epsilon:   float
2888            threshold to evaluate the rank of A
2889
2890        Implementation
2891        --------------
2892
2893        The singular value decomposition is directly related to the symmetric
2894        eigenvalue decomposition.
2895
2896        See for Reference
2897
2898        * Bunse-Gerstner et al., Numerical computation of an analytic singular value
2899          decomposition of a matrix valued function
2900
2901        * A. Bjoerk, Numerical Methods for Least Squares Problems, SIAM, 1996
2902        for the relation between SVD and symm. eigenvalue decomposition
2903
2904        * S. F. Walter, Structured Higher-Order Algorithmic Differentiation
2905        in the Forward and Reverse Mode with Application in Optimum Experimental
2906        Design, PhD thesis, 2011
2907        for the Taylor polynomial arithmetic.
2908
2909        """
2910
2911        D,P,M,N = numpy.shape(A.data)
2912        K = min(M,N)
2913
2914        if out is None:
2915            U = cls(cls.__zeros__((D,P,M,M), dtype=A.data.dtype))
2916            s = cls(cls.__zeros__((D,P,K), dtype=A.data.dtype))
2917            V = cls(cls.__zeros__((D,P,N,N), dtype=A.data.dtype))
2918
2919        # real symmetric eigenvalue decomposition
2920
2921        B = cls(cls.__zeros__((D,P, M+N, M+N), dtype=A.data.dtype))
2922        B[:M,M:] = A
2923        B[M:,:M] = A.T
2924        l,Q = cls.eigh(B, epsilon=epsilon)
2925
2926
2927        # compute the rank
2928        # FIXME: this compound algorithm should be generic, i.e., also be applicable
2929        #        in the reverse mode. Need to replace *.data accesses
2930        r = 0
2931
2932        for i in range(K):
2933            if numpy.any(abs(l[i].data) > epsilon):
2934                r = i+1
2935
2936        # resort eigenvalues from large to small
2937        # and update l and Q accordingly
2938        tmp = numpy.arange(M+N)[::-1]
2939        Pr = numpy.eye(M+N)
2940        Pr = Pr[tmp]
2941        l = cls.dot(Pr, l)
2942        Q = cls.dot(Q, Pr.T)
2943
2944        # find U S V.T
2945        U = cls(cls.__zeros__((D,P, M,M), dtype=Q.data.dtype))
2946        V = cls(cls.__zeros__((D,P, N,N), dtype=Q.data.dtype))
2947        U[:,:r] = 2.**0.5*Q[:M,:r]
2948        # compute orthogonal columns to U[:, :r]
2949        U[:, r:] = cls.qr_full(U[:,:r])[0][:, r:]
2950        # U[:,r:] = Q[:M, 2*r: r+M]
2951        V[:,:r] = 2.**0.5*Q[M:,:r]
2952        # V[:,r:] = Q[M:,r+M:]
2953        V[:, r:] = cls.qr_full(V[:,:r])[0][:, r:]
2954        s[:] = l[:K]
2955
2956        return U, s, V
2957
2958    @classmethod
2959    def pb_svd(cls, Ubar, sbar, Vbar,  A, U, s, V,  out = None):
2960        D,P,M,N = numpy.shape(A.data)
2961
2962        assert M <= N, "only M <= N is supported, please use the SVD of the transpose"
2963
2964        if out is None:
2965            Abar = A.zeros_like()
2966
2967        else:
2968            Abar, = out
2969
2970        Sbar = A.zeros_like()
2971
2972        for i in range(M):
2973            Sbar[i,i] = sbar[i]
2974
2975        Abar += UTPM.dot(U, UTPM.dot(Sbar, V.T))
2976
2977        F = U.zeros_like()
2978        for i in range(M):
2979            F[i, :] -= s[i]**2
2980
2981        for j in range(M):
2982            F[:, j] += s[j]**2
2983            F[j, j] = numpy.infty
2984
2985        F = 1./F
2986
2987
2988        B = F * UTPM.dot(U.T, Ubar)
2989        Dbar = UTPM.dot(V.T, Vbar)
2990
2991        D1bar, D2hat = Dbar[:M, :M], Dbar[:M, M:]
2992        D2til, D3bar = Dbar[M:, :M], Dbar[M:, M:]
2993
2994        D2bar = D2til.T - D2hat
2995
2996        G = F * D1bar
2997        Pbar = A.zeros_like()
2998        P1bar, P2bar = Pbar[:, :M], Pbar[:, M:]
2999
3000        P1bar[...] = s.reshape((M, 1)) * (G + G.T) + s.reshape((1, M)) * (B + B.T)
3001        P2bar[...] = D2bar/s.reshape((M, 1))
3002
3003        Abar += UTPM.dot(U, UTPM.dot(Pbar, V.T))
3004
3005        return Abar
3006
3007    @classmethod
3008    def diag(cls, v, k = 0, out = None):
3009        """Extract a diagonal or construct  diagonal UTPM instance"""
3010        return cls(cls._diag(v.data))
3011
3012    @classmethod
3013    def pb_diag(cls, ybar, x, y, k = 0, out = None):
3014        """Extract a diagonal or construct  diagonal UTPM instance"""
3015
3016        if out is None:
3017            xbar = x.zeros_like()
3018
3019        else:
3020            xbar, = out
3021
3022        return cls(cls._diag_pullback(ybar.data, x.data, y.data, k = k, out = xbar.data))
3023
3024
3025    @classmethod
3026    def symvec(cls, A, UPLO = 'F'):
3027        """
3028        maps a symmetric matrix to a vector containing the distinct elements
3029        """
3030        import algopy.utils
3031        return algopy.utils.symvec(A, UPLO=UPLO)
3032
3033    @classmethod
3034    def pb_symvec(cls, vbar, A, UPLO, v, out = None):
3035
3036        if out is None:
3037            Abar = A.zeros_like()
3038
3039        else:
3040            Abar = out[0]
3041
3042        N,M = A.shape
3043
3044        if UPLO=='F':
3045            count = 0
3046            for row in range(N):
3047                for col in range(row,N):
3048                    Abar[row,col] += 0.5 * vbar[count]
3049                    Abar[col,row] += 0.5 * vbar[count]
3050                    count +=1
3051
3052        elif UPLO=='L':
3053            count = 0
3054            for n in range(N):
3055                for m in range(n,N):
3056                    Abar[m,n] = vbar[count]
3057                    count +=1
3058
3059        elif UPLO=='U':
3060            count = 0
3061            for n in range(N):
3062                for m in range(n,N):
3063                    Abar[n,m] = vbar[count]
3064                    count +=1
3065
3066        else:
3067            err_str = "UPLO must be either 'F','L', or 'U'\n"
3068            err_str+= "however, provided UPLO=%s"%UPLO
3069            raise ValueError(err_str)
3070
3071        return Abar
3072
3073    @classmethod
3074    def vecsym(cls, v):
3075        """
3076        returns a full symmetric matrix filled
3077        the distinct elements of v, filled row-wise
3078        """
3079        D,P = v.data.shape[:2]
3080        Nv = v.data[0,0].size
3081
3082        tmp = numpy.sqrt(1 + 8*Nv)
3083        if abs(int(tmp) - tmp) > 1e-16:
3084            # hackish way to check that the input length of v makes sense
3085            raise ValueError('size of v does not match any possible symmetric matrix')
3086        N = (int(tmp) - 1)//2
3087        A = cls(numpy.zeros((D,P,N,N)))
3088
3089        count = 0
3090        for row in range(N):
3091            for col in range(row,N):
3092                A[row,col] = A[col,row] = v[count]
3093                count +=1
3094
3095        return A
3096
3097    @classmethod
3098    def pb_vecsym(cls, Abar, v, A, out = None):
3099
3100        if out is None:
3101            vbar = v.zeros_like()
3102
3103        else:
3104            vbar ,= out
3105
3106        N = A.shape[0]
3107
3108        count = 0
3109        for row in range(N):
3110            vbar[count] += Abar[row,row]
3111            count += 1
3112            for col in range(row+1,N):
3113                vbar[count] += Abar[col,row]
3114                vbar[count] += Abar[row,col]
3115                count +=1
3116
3117        return vbar
3118
3119
3120
3121    @classmethod
3122    def iouter(cls, x, y, out):
3123        cls._iouter(x.data, y.data, out.data)
3124        return out
3125
3126    def reshape(self,  newshape, order = 'C'):
3127        if order != 'C':
3128            raise NotImplementedError('should implement that')
3129        cls = self.__class__
3130        return cls(cls._reshape(self.data, newshape, order = order))
3131
3132
3133    @classmethod
3134    def combine_blocks(cls, in_X):
3135        """
3136        expects an array or list consisting of entries of type UTPM, e.g.
3137        in_X = [[UTPM1,UTPM2],[UTPM3,UTPM4]]
3138        and returns
3139        UTPM([[UTPM1.data,UTPM2.data],[UTPM3.data,UTPM4.data]])
3140
3141        """
3142
3143        in_X = numpy.array(in_X)
3144        Rb,Cb = numpy.shape(in_X)
3145
3146        # find the degree D and number of directions P
3147        D = 0; 	P = 0;
3148
3149        for r in range(Rb):
3150            for c in range(Cb):
3151                D = max(D, in_X[r,c].data.shape[0])
3152                P = max(P, in_X[r,c].data.shape[1])
3153
3154        # find the sizes of the blocks
3155        rows = []
3156        cols = []
3157        for r in range(Rb):
3158            rows.append(in_X[r,0].shape[0])
3159        for c in range(Cb):
3160            cols.append(in_X[0,c].shape[1])
3161        rowsums = numpy.array([ numpy.sum(rows[:r]) for r in range(0,Rb+1)],dtype=int)
3162        colsums = numpy.array([ numpy.sum(cols[:c]) for c in range(0,Cb+1)],dtype=int)
3163
3164        # create new matrix where the blocks will be copied into
3165        tc = numpy.zeros((D, P, rowsums[-1],colsums[-1]))
3166        for r in range(Rb):
3167            for c in range(Cb):
3168                tc[:,:,rowsums[r]:rowsums[r+1], colsums[c]:colsums[c+1]] = in_X[r,c].data[:,:,:,:]
3169
3170        return UTPM(tc)
3171
3172    @classmethod
3173    def tile(cls, A, reps, out = None):
3174        """UTPM implementation of numpy.tile(A, reps)"""
3175        D,P = A.data.shape[:2]
3176
3177        Bshp = numpy.tile(A.data[0,0], reps).shape
3178
3179        if out is None:
3180            B = cls( numpy.zeros( (D,P) + Bshp, dtype=A.dtype))
3181
3182        else:
3183            r, = out
3184
3185        for d in range(D):
3186            for p in range(P):
3187                B.data[d,p, ...] = numpy.tile(A.data[d,p], reps)
3188
3189        return B
3190
3191    @classmethod
3192    def pb_tile(cls, Bbar, A, reps, B, out = None):
3193
3194        if(isinstance(reps, int)):
3195            reps=[reps]
3196
3197        d = len(reps)
3198
3199        assert Bbar.shape == B.shape
3200
3201        if A.ndim < d:
3202            A2shp = [1]*(d - A.ndim) + list(A.shape)
3203            d2shp = list(reps)
3204        elif A.ndim > d:
3205            A2shp = list(A.shape)
3206            d2shp = [1]*(A.ndim - d) + list(reps)
3207        else:
3208            A2shp = list(A.shape)
3209            d2shp = list(reps)
3210
3211        if out is None:
3212            Abar = A.zeros_like()
3213
3214        else:
3215            Abar = out[0]
3216
3217        A2shp = numpy.array(A2shp)
3218
3219        # loop over all tiles and add tile to Abar
3220        # K = [7*3*2, 3*2, 2, 1]
3221        K = [int(numpy.prod(d2shp[::-1][:k])) for k in range(1+len(d2shp))][::-1]
3222        for i in range(K[0]):
3223            # convert i into multi-index
3224            m = numpy.array([(i/K[j+1]) % d2shp[j] for j in range(len(d2shp))])
3225            # create slice of B
3226            s = [slice(A2shp[k]*m[k], A2shp[k]*(m[k]+1)) for k in range(len(m))]
3227            Abar += Bbar[s]
3228
3229        return Abar
3230
3231    @classmethod
3232    def fft(cls, a, n=None, axis=-1, out=None):
3233        """UTPM equivalent to numpy.fft.fft(a, n=None, axis=-1)"""
3234        D,P = a.data.shape[:2]
3235
3236        if out is None:
3237            r = cls(numpy.zeros(a.data.shape, dtype=complex))
3238
3239        else:
3240            r, = out
3241
3242        for d in range(D):
3243            for p in range(P):
3244                r.data[d,p, ...] = numpy.fft.fft(a.data[d,p], n=n, axis=axis)
3245
3246        return r
3247
3248    @classmethod
3249    def pb_fft(cls, bbar, a, b, n=None, axis=-1, out=None):
3250        D,P = a.data.shape[:2]
3251
3252        if out is None:
3253            abar = cls(numpy.zeros(a.data.shape, dtype=complex))
3254
3255        else:
3256            abar, = out
3257
3258        for d in range(D):
3259            for p in range(P):
3260
3261                # abar.data[d,p, ...] += numpy.fft.fft(bbar.data[d,p], n=n, axis=axis)
3262                numpy.add(abar.data[d,p, ...], numpy.fft.fft(bbar.data[d,p], n=n, axis=axis), out=abar.data[d,p, ...], casting="unsafe")
3263
3264        return abar
3265
3266    @classmethod
3267    def ifft(cls, a, n=None, axis=-1, out=None):
3268        """UTPM equivalent to numpy.fft.ifft(a, n=None, axis=-1)"""
3269        D,P = a.data.shape[:2]
3270
3271        if out is None:
3272            r = cls(numpy.zeros(a.data.shape, dtype=complex))
3273
3274        else:
3275            r, = out
3276
3277        for d in range(D):
3278            for p in range(P):
3279                r.data[d,p, ...] = numpy.fft.ifft(a.data[d,p], n=n, axis=axis)
3280
3281        return r
3282
3283    @classmethod
3284    def pb_ifft(cls, bbar, a, b, n=None, axis=-1, out=None):
3285        D,P = a.data.shape[:2]
3286
3287        if out is None:
3288            abar = cls(numpy.zeros(a.data.shape, dtype=complex))
3289
3290        else:
3291            abar, = out
3292
3293        for d in range(D):
3294            for p in range(P):
3295                abar.data[d,p, ...] += numpy.fft.ifft(bbar.data[d,p], n=n, axis=axis)
3296
3297        return abar
3298
3299
3300
3301
3302class UTP(UTPM):
3303    """
3304    UTP(X, vectorized=False)
3305
3306    Univariate Taylor Polynomial (UTP)
3307    with coefficients that are arbitrary numpy.ndarrays
3308
3309    Attributes
3310    ----------
3311    data: numpy.ndarray
3312        underlying datastructure, a numpy.array of shape (D,P) + UTP.shape
3313
3314    coeff: numpy.ndarray like structure
3315        is accessed just like UTP.data but has the shape (D,) + UTP.shape if
3316        vectorized=False and exactly the same as UTP.data when vectorized=True
3317
3318    vectorized: bool
3319        whether the UTP is vectorized or not (default is False)
3320
3321    All other attributes are motivated from numpy.ndarray and return
3322    size, shape, ndim of an individual coefficient of the UTP. E.g.,
3323
3324    T: UTP
3325        Transpose of the UTP
3326    size: int
3327        Number of elements in a UTP coefficient
3328    shape: tuple of ints
3329        Shape of a UTP coefficient
3330    ndim: int
3331        The number of dimensions of a UTP coefficient
3332
3333    Parameters
3334    ----------
3335
3336    X: numpy.ndarray with shape (D, P, N1, N2, N3, ...) if vectorized=True
3337       otherwise a (D, N1, N2, N3, ...) array
3338
3339
    Remark:
        This class provides an improved user interface compared to the class UTPM.

        The difference is mainly the initialization.

        E.g.::

            x = UTP([1,2,3])

        is equivalent to::

            x = UTP([1,2,3], vectorized=False)
            x = UTPM([[1],[2],[3]])

        and::

            x = UTP([[1,2],[2,3],[3,4]], vectorized=True)

        is equivalent to::

            x = UTPM([[1,2],[2,3],[3,4]])
3360    """
3361
3362    def __init__(self, X, vectorized=False):
3363        """
3364        see self.__class__.__doc__ for information
3365        """
3366        Ndim = numpy.ndim(X)
3367        self.vectorized = vectorized
3368        if Ndim >= 1:
3369            self.data = numpy.asarray(X)
3370            if vectorized == False:
3371                shp = self.data.shape
3372                self.data = self.data.reshape(shp[:1] + (1,) + shp[1:])
3373        else:
3374            raise NotImplementedError
3375
3376    @property
3377    def coeff(self, *args, **kwargs):
3378        if self.vectorized == False:
3379            return self.data[:,0,...]
3380        else:
3381            return self.data
3382
3383    def __str__(self):
3384        """ return string representation """
3385        return str(self.coeff)
3386