1"""
2This file contains the core algorithms for
3
4* the forward mode (univariate Taylor polynomial arithmetic)
5* the reverse mode
6
7The functions are operating solely on numpy datastructures.
8
9Rationale
10---------
11
12If speed is an issue, one can rather easily replace
13the function implementations by C or Fortran functions.
14
15"""
16
17import math
18import functools
19
20import numpy
21from numpy.lib.stride_tricks import as_strided, broadcast_arrays
22
23try:
24    import scipy.linalg
25    import scipy.special
26except ImportError:
27    pass
28
29try:
30    import pytpcore
31except ImportError:
32    pytpcore = None
33
34from algopy import nthderiv
35
36
37def _plus_const(x_data, c, out=None):
38    """
39    Constants are only added to the d=0 slice of the data array.
40    A function like this is not so useful for multiplication by a constant,
41    because UTPM multiplication by a constant scales the entire data array
42    rather than acting on only the d=0 slice.
43    """
44    if out is None:
45        y_data = numpy.copy(x_data)
46    else:
47        y_data = out
48    y_data[0] += c
49    return y_data
50
51def _eval_slow_generic(f, x_data, out=None):
52    """
53    This is related to summations associated with the name 'Faa di Bruno.'
54    @param f: f(X, out=None, n=0) computes nth derivative of f at X
55    @param x_data: something about algorithmic differentiation
56    @param out: something about algorithmic differentiation
57    @param return: something about algorithmic differentiation
58    """
59    #FIXME: Improve or replace this function.
60    # It is intended to help with naive implementations
61    # of truncated taylor expansions
62    # of functions of a low degree polynomial,
63    # when the nth derivatives of the function of interest
64    # can be computed more or less directly.
65
66    y_data = nthderiv.np_filled_like(x_data, 0, out=out)
67    D, P = x_data.shape[:2]
68
69    # base point: d = 0
70    y_data[0] = f(x_data[0])
71
72    # higher order coefficients: d > 0
73    for d in range(1, D):
74        # Accumulate coefficients of truncated expansions of powers
75        # of the polynomial.
76        if d == 1:
77            accum = x_data[1:].copy()
78        else:
79            for i in range(D-2, 0, -1):
80                accum[i] = numpy.sum(accum[:i] * x_data[i:0:-1], axis=0)
81            accum[0] = 0.
82        # Add the contribution of this summation term.
83        y_data[1:] += f(x_data[0], n=d) * accum / float(math.factorial(d))
84
85    return y_data
86
87def _black_f_white_fprime(f, fprime_data, x_data, out=None):
88    """
89    The function evaluation is a black box, but the derivative is compound.
90    @param f: computes the scalar function directly
91    @param fprime_data: the array associated with the evaluated derivative
92    @param x_data: something about algorithmic differentiation
93    @param out: something about algorithmic differentiation
94    @param return: something about algorithmic differentiation
95    """
96
97    y_data = nthderiv.np_filled_like(x_data, 0, out=out)
98    D, P = x_data.shape[:2]
99
100    # Do the direct computation efficiently (e.g. using C implemention of erf).
101    y_data[0] = f(x_data[0])
102
103    # Compute the truncated series coefficients using discrete convolution.
104    #FIXME: one of these two loops can be vectorized
105    for d in range(1, D):
106        for c in range(d):
107            y_data[d] += fprime_data[d-1-c] * x_data[c+1] * (c+1)
108        y_data[d] /= d
109
110    return y_data
111
112def _taylor_polynomials_of_ode_solutions(
113        a_data, b_data, c_data,
114        u_data, v_data,
115        ):
116    """
117    This is a general O(D^2) algorithm for functions that are ODE solutions.
118    It is an attempt to implement Proposition 13.1
119    of "Evaluating Derivatives" by Griewank and Walther (2008).
120    The function must satisfy the identity
121    b(u) f'(u) - a(u) f(u) = c(u)
122    where a, b and c are already represented by their Taylor expansions.
123    Also u is represented as a Taylor expansion, and so is v.
124    But we are only given the first term of v, which is the recursion base.
125    In this function we use the notation from the book mentioned above.
126    """
127
128    # define the number of terms allowed in the truncated series
129    D = u_data.shape[0]
130    d = D-1
131
132    # these arrays have elements that are scaled slightly differently
133    u_tilde_data = u_data.copy()
134    v_tilde_data = v_data.copy()
135    for j in range(1, D):
136        u_tilde_data[j] *= j
137        v_tilde_data[j] *= j
138
139    # this is just convenient temporary storage which is not so important
140    s = numpy.zeros_like(u_data)
141
142    # on the other hand the e_data is very important for recursion
143    e_data = numpy.zeros_like(u_data)
144
145    # do the dynamic programming to fill the v_data array
146    for k in range(D):
147        if k > 0:
148            for j in range(1, k+1):
149                s[k] += (c_data[k-j] + e_data[k-j]) * u_tilde_data[j]
150            for j in range(1, k):
151                s[k] -= b_data[k-j] * v_tilde_data[j]
152            v_tilde_data[k] = s[k] / b_data[0]
153            v_data[k] = v_tilde_data[k] / k
154        if k < d:
155            for j in range(k+1):
156                e_data[k] += a_data[j] * v_data[k-j]
157
158    return v_data
159
160
161def vdot(x,y, z = None):
162    """
163    vectorized dot
164
165    z = vdot(x,y)
166
167    Rationale:
168
169        given two iteratable containers (list,array,...) x and y
170        this function computes::
171
172            z[i] = numpy.dot(x[i],y[i])
173
174        if z is None, this function allocates the necessary memory
175
176    Warning: the naming is inconsistent with numpy.vdot
177    Warning: this is a preliminary version that is likely to be changed
178    """
179    x_shp = numpy.shape(x)
180    y_shp = numpy.shape(y)
181
182    if x_shp[-1] != y_shp[-2]:
183        raise ValueError('got x.shape = %s and y.shape = %s'%(str(x_shp),str(y_shp)))
184
185    if numpy.ndim(x) == 3:
186        P,N,M  = x_shp
187        P,M,K  = y_shp
188        retval = numpy.zeros((P,N,K))
189        for p in range(P):
190            retval[p,:,:] = numpy.dot(x[p,:,:], y[p,:,:])
191
192        return retval
193
194    elif numpy.ndim(x) == 4:
195        D,P,N,M  = x_shp
196        D,P,M,K  = y_shp
197        retval = numpy.zeros((D,P,N,K))
198        for d in range(D):
199            for p in range(P):
200                retval[d,p,:,:] = numpy.dot(x[d,p,:,:], y[d,p,:,:])
201
202        return retval
203
204def truncated_triple_dot(X,Y,Z, D):
205    """
206    computes d^D/dt^D ( [X]_D [Y]_D [Z]_D) with t set to zero after differentiation
207
208    X,Y,Z are (DT,P,N,M) arrays s.t. the dimensions match to compute dot(X[d,p,:,:], dot(Y[d,p,:,:], Z[d,p,:,:]))
209
210    """
211    import algopy.exact_interpolation
212    noP = False
213    if len(X.shape) == 3:
214        noP = True
215        DT,NX,MX = X.shape
216        X = X.reshape((DT,1,NX,MX))
217
218    if len(Y.shape) == 3:
219        noP = True
220        DT,NY,MY = Y.shape
221        Y = Y.reshape((DT,1,NY,MY))
222
223    if len(Z.shape) == 3:
224        noP = True
225        DT,NZ,MZ = Z.shape
226        Z = Z.reshape((DT,1,NZ,MZ))
227
228    DT,P,NX,MX = X.shape
229    DT,P,NZ,MZ = Z.shape
230
231    multi_indices = algopy.exact_interpolation.generate_multi_indices(3,D)
232    retval = numpy.zeros((P,NX,MZ))
233
234    for mi in multi_indices:
235        for p in range(P):
236            if mi[0] == D or mi[1] == D or mi[2] == D:
237                continue
238            retval[p] += numpy.dot(X[mi[0],p,:,:], numpy.dot(Y[mi[1],p,:,:], Z[mi[2],p,:,:]))
239
240    if noP == False:
241        return retval
242    else:
243        return retval[0]
244
245def broadcast_arrays_shape(x_shp,y_shp):
246
247    if len(x_shp) < len(y_shp):
248        tmp = x_shp
249        x_shp = y_shp
250        y_shp = tmp
251
252    z_shp = numpy.array(x_shp,dtype=int)
253    for l in range(1,len(y_shp)-1):
254        if z_shp[-l] == 1: z_shp[-l] = y_shp[-l]
255        elif z_shp[-l] != 1 and y_shp[-l] != 1 and z_shp[-l] != y_shp[-l]:
256            raise ValueError('cannot broadcast arrays')
257
258
259    return z_shp
260
261
262class RawAlgorithmsMixIn:
263
264    @classmethod
265    def _broadcast_arrays(cls, x_data, y_data):
266        """ UTPM equivalent of numpy.broadcast_arrays """
267
268        # transpose arrays s.t. numpy.broadcast can be used
269        Lx = len(x_data.shape)
270        Ly = len(y_data.shape)
271        x_data = x_data.transpose( tuple(range(2,Lx)) + (0,1))
272        y_data = y_data.transpose( tuple(range(2,Ly)) + (0,1))
273
274        # broadcast arrays
275        x_data, y_data = broadcast_arrays(x_data, y_data)
276
277
278        # transpose into the original format
279        Lx = len(x_data.shape)
280        Ly = len(y_data.shape)
281        x_data = x_data.transpose( (Lx-2, Lx-1) +  tuple(range(Lx-2)) )
282        y_data = y_data.transpose( (Ly-2, Ly-1) +  tuple(range(Lx-2)) )
283
284        return x_data, y_data
285
286    @classmethod
287    def _mul(cls, x_data, y_data, out=None):
288        """
289        z = x*y
290        """
291        if numpy.shape(x_data) != numpy.shape(y_data):
292            raise NotImplementedError
293        D, P = x_data.shape[:2]
294        #FIXME: there is a memoryview and buffer contiguity checking error
295        # which may or may not be caused by a bug in numpy or cython.
296        if pytpcore and all(s > 1 for s in x_data.shape):
297            # tp_mul is not careful about aliasing
298            z_data = numpy.empty_like(x_data)
299            x_data_reshaped = x_data.reshape((D, -1))
300            y_data_reshaped = y_data.reshape((D, -1))
301            z_data_reshaped = z_data.reshape((D, -1))
302            pytpcore.tp_mul(x_data_reshaped, y_data_reshaped, z_data_reshaped)
303            if out is not None:
304                out[...] = z_data_reshaped.reshape((z_data.shape))
305                return out
306            else:
307                return z_data
308        else:
309            # numpy.sum is careful about aliasing so we can use out=z_data
310            if out is None:
311                z_data = numpy.empty_like(x_data)
312            else:
313                z_data = out
314            for d in range(D)[::-1]:
315                numpy.sum(
316                        x_data[:d+1,:,...] * y_data[d::-1,:,...],
317                        axis=0,
318                        out = z_data[d,:,...])
319            return z_data
320
321
322    @classmethod
323    def _minimum(cls, x_data, y_data, out=None):
324        if x_data.shape != y_data.shape:
325            raise NotImplementedError(
326                    'algopy broadcasting is not implemented for this function')
327        D = x_data.shape[0]
328        xmask = numpy.less_equal(x_data[0], y_data[0])
329        ymask = 1 - xmask
330        z_data = numpy.empty_like(x_data)
331        for d in range(D):
332            numpy.add(xmask * x_data[d], ymask * y_data[d], out=z_data[d])
333        if out is not None:
334            out[...] = z_data[...]
335            return out
336        else:
337            return z_data
338
339    @classmethod
340    def _maximum(cls, x_data, y_data, out=None):
341        if x_data.shape != y_data.shape:
342            raise NotImplementedError(
343                    'algopy broadcasting is not implemented for this function')
344        D = x_data.shape[0]
345        xmask = numpy.greater_equal(x_data[0], y_data[0])
346        ymask = 1 - xmask
347        z_data = numpy.empty_like(x_data)
348        for d in range(D):
349            numpy.add(xmask * x_data[d], ymask * y_data[d], out=z_data[d])
350        if out is not None:
351            out[...] = z_data[...]
352            return out
353        else:
354            return z_data
355
356    @classmethod
357    def _amul(cls, x_data, y_data, out = None):
358        """
359        z += x*y
360        """
361        z_data = out
362        if out is None:
363            raise NotImplementedError
364
365        (D,P) = z_data.shape[:2]
366        for d in range(D):
367            z_data[d,:,...] +=  numpy.sum(x_data[:d+1,:,...] * y_data[d::-1,:,...], axis=0)
368
369    @classmethod
370    def _itruediv(cls, z_data, x_data):
371        (D,P) = z_data.shape[:2]
372        tmp_data = z_data.copy()
373        for d in range(D):
374            tmp_data[d,:,...] = 1./ x_data[0,:,...] * ( z_data[d,:,...] - numpy.sum(tmp_data[:d,:,...] * x_data[d:0:-1,:,...], axis=0))
375        z_data[...] = tmp_data[...]
376
377    @classmethod
378    def _truediv(cls, x_data, y_data, out = None):
379        """
380        z = x/y
381        """
382        if out is None:
383            raise NotImplementedError
384
385        z_data = numpy.empty_like(out)
386        (D,P) = z_data.shape[:2]
387        for d in range(D):
388            z_data[d,:,...] = 1./ y_data[0,:,...] * ( x_data[d,:,...] - numpy.sum(z_data[:d,:,...] * y_data[d:0:-1,:,...], axis=0))
389
390        out[...] = z_data[...]
391        return out
392
393    @classmethod
394    def _reciprocal(cls, y_data, out=None):
395        """
396        z = 1/y
397        """
398        #FIXME: this function could use some attention;
399        # it was copypasted from div
400        z_data = numpy.empty_like(y_data)
401        D = y_data.shape[0]
402        if pytpcore:
403            y_data_reshaped = y_data.reshape((D, -1))
404            z_data_reshaped = z_data.reshape((D, -1))
405            pytpcore.tp_reciprocal(y_data_reshaped, z_data_reshaped)
406        else:
407            for d in range(D):
408                if d == 0:
409                    z_data[d,:,...] = 1./ y_data[0,:,...] * ( 1 - numpy.sum(z_data[:d,:,...] * y_data[d:0:-1,:,...], axis=0))
410                else:
411                    z_data[d,:,...] = 1./ y_data[0,:,...] * ( 0 - numpy.sum(z_data[:d,:,...] * y_data[d:0:-1,:,...], axis=0))
412
413        if out is not None:
414            out[...] = z_data[...]
415            return out
416        else:
417            return z_data
418
419    @classmethod
420    def _pb_reciprocal(cls, ybar_data, x_data, y_data, out=None):
421        if out is None:
422            raise NotImplementedError('should implement that')
423        #FIXME: this is probably dumb
424        tmp = -cls._reciprocal(cls._square(x_data))
425        cls._amul(ybar_data, tmp, out=out)
426
427    @classmethod
428    def _floordiv(cls, x_data, y_data, out = None):
429        """
430        z = x // y
431
432        use L'Hospital's rule when leading coefficients of y_data are zero
433
434        """
435        z_data = out
436        if out is None:
437            raise NotImplementedError
438
439        (D,P) = z_data.shape[:2]
440
441        x_data = x_data.copy()
442        y_data = y_data.copy()
443
444        #print x_data
445        #print y_data
446
447
448        # left shifting x_data and y_data if necessary
449
450        mask = Ellipsis
451        while True:
452            mask = numpy.where( abs(y_data[0, mask]) <= 1e-8)
453
454            if len(mask[0]) == 0:
455                break
456            elif len(mask) == 1:
457                mask = mask[0]
458
459            x_data[:D-1, mask] = x_data[1:, mask]
460            x_data[D-1,  mask] = 0.
461
462            y_data[:D-1, mask] = y_data[1:, mask]
463            y_data[D-1,  mask] = 0.
464
465        for d in range(D):
466            z_data[d,:,...] = 1./ y_data[0,:,...] * \
467                         ( x_data[d,:,...]
468                           - numpy.sum(z_data[:d,:,...] * y_data[d:0:-1,:,...],
469                           axis=0)
470                         )
471
472    @classmethod
473    def _pow_real(cls, x_data, r, out = None):
474        """ y = x**r, where r is scalar """
475        y_data = out
476        if out is None:
477            raise NotImplementedError
478        (D,P) = y_data.shape[:2]
479
480        if type(r) == int and r >= 0:
481            if r == 0:
482                y_data[...] = 0.
483                y_data[0, ...] = 1.
484                return y_data
485
486            elif r == 1:
487                y_data[...] = x_data[...]
488                return y_data
489
490            elif r == 2:
491                return cls._square(x_data, out=y_data)
492
493            elif r >= 3:
494                y_data[...] = x_data[...]
495                for nr in range(r-1):
496                    cls._mul(x_data, y_data, y_data)
497                return
498
499            else:
500                raise NotImplementedError("power to %d is not implemented" % r)
501
502
503
504
505        y_data[0] = x_data[0]**r
506        for d in range(1,D):
507            y_data[d] = r * numpy.sum([y_data[d-k] * k * x_data[k] for k in range(1,d+1)], axis = 0) - \
508                numpy.sum([ x_data[d-k] * k * y_data[k] for k in range(1,d)], axis = 0)
509
510            y_data[d] /= x_data[0]
511            y_data[d] /= d
512
513    @classmethod
514    def _pb_pow_real(cls, ybar_data, x_data, r, y_data, out = None):
515        """ pullback function of y = pow(x,r) """
516        if out is None:
517            raise NotImplementedError('should implement that')
518
519        xbar_data = out
520        (D,P) = y_data.shape[:2]
521
522        # if r == 0:
523            # raise NotImplementedError('x**0 is special and has not been implemented')
524
525        # if type(r) == int:
526            # if r == 2:
527
528        # print 'r=',r
529        # print 'x_data=',x_data
530        # print 'y_data=',y_data
531        # print 'xbar_data=',xbar_data
532        # print 'ybar_data=',ybar_data
533
534        if type(r) == int:
535
536            if r > 0:
537
538                tmp = numpy.zeros_like(xbar_data)
539                cls._pow_real(x_data, r - 1, out = tmp)
540                tmp *= r
541                cls._mul(ybar_data, tmp, tmp)
542                xbar_data += tmp
543
544        else:
545
546            tmp = numpy.zeros_like(xbar_data)
547
548            cls._truediv(y_data, x_data, tmp)
549            tmp[...] = numpy.nan_to_num(tmp)
550            cls._mul(ybar_data, tmp, tmp)
551            tmp *= r
552
553            xbar_data += tmp
554
555        # print 'xbar_data=',xbar_data
556
557
558    @classmethod
559    def _max(cls, x_data, axis = None, out = None):
560
561        if out is None:
562            raise NotImplementedError('should implement that')
563
564        x_shp = x_data.shape
565
566        D,P = x_shp[:2]
567        shp = x_shp[2:]
568
569        if len(shp) > 1:
570            raise NotImplementedError('should implement that')
571
572        for p in range(P):
573            out[:,p] = x_data[:,p,numpy.argmax(x_data[0,p])]
574
575
576    @classmethod
577    def _argmax(cls, a_data, axis = None):
578
579        if axis is not None:
580            raise NotImplementedError('should implement that')
581
582        a_shp = a_data.shape
583        D,P = a_shp[:2]
584        return numpy.argmax(a_data[0].reshape((P,numpy.prod(a_shp[2:]))), axis = 1)
585
586    @classmethod
587    def _absolute(cls, x_data, out=None):
588        """
589        z = |x|
590        """
591        if out is None:
592            z_data = numpy.empty_like(x_data)
593        else:
594            z_data = out
595        D = x_data.shape[0]
596        if D > 1:
597            x_data_sign = numpy.sign(x_data[0])
598        for d in range(D):
599            if d == 0:
600                numpy.absolute(x_data[d], out=z_data[d])
601            else:
602                numpy.multiply(x_data[d], x_data_sign, out=z_data[d])
603        return z_data
604
605    @classmethod
606    def _pb_absolute(cls, ybar_data, x_data, y_data, out = None):
607        if out is None:
608            raise NotImplementedError('should implement that')
609        fprime_data = numpy.empty_like(x_data)
610        D = x_data.shape[0]
611        for d in range(D):
612            if d == 0:
613                numpy.sign(x_data[d], out=fprime_data[d])
614            else:
615                fprime_data[d].fill(0)
616        cls._amul(ybar_data, fprime_data, out=out)
617
618    @classmethod
619    def _negative(cls, x_data, out=None):
620        """
621        z = -x
622        """
623        return numpy.multiply(x_data, -1, out=out)
624
625    @classmethod
626    def _pb_negative(cls, ybar_data, x_data, y_data, out = None):
627        if out is None:
628            raise NotImplementedError('should implement that')
629        fprime_data = numpy.empty_like(x_data)
630        fprime_data[0].fill(-1)
631        fprime_data[1:].fill(0)
632        cls._amul(ybar_data, fprime_data, out=out)
633
634    @classmethod
635    def _square(cls, x_data, out=None):
636        """
637        z = x*x
638        This can theoretically be twice as efficient as mul(x, x).
639        """
640        if out is None:
641            z_data = numpy.empty_like(x_data)
642        else:
643            z_data = out
644        tmp = numpy.zeros_like(x_data)
645        D, P = x_data.shape[:2]
646        for d in range(D):
647            d_half = (d+1) // 2
648            if d:
649                AB = x_data[:d_half, :, ...] * x_data[d:d-d_half:-1, :, ...]
650                numpy.sum(AB * 2, axis=0, out=tmp[d, :, ...])
651            if (d+1) % 2 == 1:
652                tmp[d, :, ...] += numpy.square(x_data[d_half, :, ...])
653        z_data[...] = tmp[...]
654        return z_data
655
656    @classmethod
657    def _pb_square(cls, ybar_data, x_data, y_data, out = None):
658        if out is None:
659            raise NotImplementedError('should implement that')
660        cls._amul(ybar_data, x_data*2, out=out)
661
662    @classmethod
663    def _sqrt(cls, x_data, out = None):
664        if out is None:
665            raise NotImplementedError('should implement that')
666        y_data = numpy.zeros_like(x_data)
667        D,P = x_data.shape[:2]
668
669        y_data[0] = numpy.sqrt(x_data[0])
670        for k in range(1,D):
671            y_data[k] = 1./(2.*y_data[0]) * ( x_data[k] - numpy.sum( y_data[1:k] * y_data[k-1:0:-1], axis=0))
672        out[...] = y_data[...]
673        return out
674
675    @classmethod
676    def _pb_sqrt(cls, ybar_data, x_data, y_data, out = None):
677        if out is None:
678            raise NotImplementedError('should implement that')
679
680        xbar_data = out
681        tmp = xbar_data.copy()
682        cls._truediv(ybar_data, y_data, tmp)
683        tmp /= 2.
684        xbar_data += tmp
685        return xbar_data
686
687    @classmethod
688    def _exp(cls, x_data, out=None):
689        if out is None:
690            y_data = numpy.empty_like(x_data)
691        else:
692            y_data = out
693        D,P = x_data.shape[:2]
694        if pytpcore:
695            x_data_reshaped = x_data.reshape((D, -1))
696            y_data_reshaped = y_data.reshape((D, -1))
697            tmp = numpy.empty_like(x_data_reshaped)
698            pytpcore.tp_exp(x_data_reshaped, tmp, y_data_reshaped)
699        else:
700            y_data[0] = numpy.exp(x_data[0])
701            xtctilde = x_data[1:].copy()
702            for d in range(1,D):
703                xtctilde[d-1] *= d
704            for d in range(1, D):
705                y_data[d] = numpy.sum(y_data[:d][::-1]*xtctilde[:d], axis=0)/d
706        return y_data
707
708    @classmethod
709    def _pb_exp(cls, ybar_data, x_data, y_data, out = None):
710        if out is None:
711            raise NotImplementedError('should implement that')
712
713        xbar_data = out
714        cls._amul(ybar_data, y_data, xbar_data)
715
716    @classmethod
717    def _expm1(cls, x_data, out=None):
718        fprime_data = cls._exp(x_data)
719        return _black_f_white_fprime(
720                nthderiv.expm1, fprime_data, x_data, out=out)
721
722    @classmethod
723    def _pb_expm1(cls, ybar_data, x_data, y_data, out = None):
724        if out is None:
725            raise NotImplementedError('should implement that')
726        fprime_data = cls._exp(x_data)
727        cls._amul(ybar_data, fprime_data, out=out)
728
729    @classmethod
730    def _logit(cls, x_data, out=None):
731        fprime_data = cls._reciprocal(x_data - cls._square(x_data))
732        return _black_f_white_fprime(
733                scipy.special.logit, fprime_data, x_data, out=out)
734
735    @classmethod
736    def _pb_logit(cls, ybar_data, x_data, y_data, out = None):
737        if out is None:
738            raise NotImplementedError('should implement that')
739        fprime_data = cls._reciprocal(x_data - cls._square(x_data))
740        cls._amul(ybar_data, fprime_data, out=out)
741
742    @classmethod
743    def _expit(cls, x_data, out=None):
744        b_data = cls._reciprocal(_plus_const(cls._exp(x_data), 1))
745        fprime_data = b_data - cls._square(b_data)
746        return _black_f_white_fprime(
747                scipy.special.expit, fprime_data, x_data, out=out)
748
749    @classmethod
750    def _pb_expit(cls, ybar_data, x_data, y_data, out = None):
751        if out is None:
752            raise NotImplementedError('should implement that')
753        b_data = cls._reciprocal(_plus_const(cls._exp(x_data), 1))
754        fprime_data = b_data - cls._square(b_data)
755        cls._amul(ybar_data, fprime_data, out=out)
756
757    @classmethod
758    def _sign(cls, x_data, out = None):
759        if out is None:
760            raise NotImplementedError('should implement that')
761        y_data = out
762        D, P = x_data.shape[:2]
763        y_data[0] = numpy.sign(x_data[0])
764        y_data[1:].fill(0)
765        return y_data
766
767    @classmethod
768    def _pb_sign(cls, ybar_data, x_data, y_data, out = None):
769        if out is None:
770            raise NotImplementedError('should implement that')
771        xbar_data = out
772        tmp = numpy.zeros_like(x_data)
773        cls._amul(ybar_data, tmp, xbar_data)
774
775    @classmethod
776    def _botched_clip(cls, a_min, a_max, x_data, out= None):
777        """
778        In this function the args are permuted w.r.t numpy.
779        """
780        if out is None:
781            raise NotImplementedError('should implement that')
782        y_data = out
783        D, P = x_data.shape[:2]
784        y_data[0] = numpy.clip(x_data[0], a_min, a_max)
785        mask = numpy.logical_and(
786                numpy.less_equal(x_data[0], a_max),
787                numpy.greater_equal(x_data[0], a_min))
788        for d in range(1, D):
789            y_data[d] *= mask
790        return y_data
791
792    @classmethod
793    def _pb_botched_clip(
794            cls, ybar_data, a_min, a_max, x_data, y_data, out=None):
795        """
796        In this function the args are permuted w.r.t numpy.
797        """
798        if out is None:
799            raise NotImplementedError('should implement that')
800        xbar_data = out
801        tmp = numpy.zeros_like(x_data)
802        numpy.multiply(
803                numpy.less_equal(x_data[0], a_max),
804                numpy.greater_equal(x_data[0], a_min),
805                out=tmp[0])
806        cls._amul(ybar_data, tmp, xbar_data)
807
808
809    @classmethod
810    def _log(cls, x_data, out = None):
811        if out is None:
812            raise NotImplementedError('should implement that')
813        y_data = numpy.empty_like(x_data)
814        D,P = x_data.shape[:2]
815
816        # base point: d = 0
817        y_data[0] = numpy.log(x_data[0])
818
819        # higher order coefficients: d > 0
820
821        for d in range(1,D):
822            y_data[d] =  (x_data[d]*d - numpy.sum(x_data[1:d][::-1] * y_data[1:d], axis=0))
823            y_data[d] /= x_data[0]
824
825        for d in range(1,D):
826            y_data[d] /= d
827
828        out[...] = y_data[...]
829        return out
830
831    @classmethod
832    def _pb_log(cls, ybar_data, x_data, y_data, out = None):
833        if out is None:
834            raise NotImplementedError('should implement that')
835        xbar_data = out
836        xbar_data += cls._truediv(ybar_data, x_data, numpy.empty_like(xbar_data))
837        return xbar_data
838
839    @classmethod
840    def _log1p(cls, x_data, out=None):
841        fprime_data = cls._reciprocal(_plus_const(x_data, 1))
842        return _black_f_white_fprime(
843                numpy.log1p, fprime_data, x_data, out=out)
844
845    @classmethod
846    def _pb_log1p(cls, ybar_data, x_data, y_data, out=None):
847        if out is None:
848            raise NotImplementedError('should implement that')
849        xbar_data = out
850        xbar_data += cls._truediv(
851                ybar_data, _plus_const(x_data, 1), numpy.empty_like(xbar_data))
852        return xbar_data
853
854    @classmethod
855    def _dawsn(cls, x_data, out=None):
856        if out is None:
857            v_data = numpy.empty_like(x_data)
858        else:
859            v_data = out
860
861        # construct the u and v arrays
862        u_data = x_data
863        v_data[0, ...] = scipy.special.dawsn(u_data[0])
864
865        # construct values like in Table (13.2) of "Evaluating Derivatives"
866        a_data = -2 * u_data.copy()
867        b_data = _plus_const(numpy.zeros_like(u_data), 1)
868        c_data = _plus_const(numpy.zeros_like(u_data), 1)
869
870        # fill the rest of the v_data
871        _taylor_polynomials_of_ode_solutions(
872            a_data, b_data, c_data,
873            u_data, v_data)
874
875        return v_data
876
877    @classmethod
878    def _pb_dawsn(cls, ybar_data, x_data, y_data, out=None):
879        if out is None:
880            raise NotImplementedError('should implement that')
881        fprime_data = _plus_const(-2*cls._mul(x_data, cls._dawsn(x_data)), 1)
882        cls._amul(ybar_data, fprime_data, out=out)
883
884    @classmethod
885    def _tansec2(cls, x_data, out = None):
886        """ computes tan and sec in Taylor arithmetic"""
887        if out is None:
888            raise NotImplementedError('should implement that')
889        y_data, z_data = out
890        D,P = x_data.shape[:2]
891
892        # base point: d = 0
893        y_data[0] = numpy.tan(x_data[0])
894        z_data[0] = 1./(numpy.cos(x_data[0])*numpy.cos(x_data[0]))
895
896        # higher order coefficients: d > 0
897        for d in range(1,D):
898            y_data[d] = numpy.sum([k*x_data[k] * z_data[d-k] for k in range(1,d+1)], axis = 0)/d
899            z_data[d] = 2.*numpy.sum([k*y_data[k] * y_data[d-k] for k in range(1,d+1)], axis = 0)/d
900
901        return y_data, z_data
902
903    @classmethod
904    def _pb_tansec(cls, ybar_data, zbar_data, x_data, y_data, z_data, out = None):
905        if out is None:
906            raise NotImplementedError('should implement that')
907
908        xbar_data = out
909        cls._mul(2*zbar_data, y_data, y_data)
910        y_data += ybar_data
911        cls._amul(y_data, z_data, xbar_data)
912
913
914    @classmethod
915    def _sincos(cls, x_data, out = None):
916        """ computes sin and cos in Taylor arithmetic"""
917        if out is None:
918            raise NotImplementedError('should implement that')
919        s_data,c_data = out
920        D,P = x_data.shape[:2]
921
922        # base point: d = 0
923        s_data[0] = numpy.sin(x_data[0])
924        c_data[0] = numpy.cos(x_data[0])
925
926        # higher order coefficients: d > 0
927        for d in range(1,D):
928            s_data[d] = numpy.sum([k*x_data[k] * c_data[d-k] for k in range(1,d+1)], axis = 0)/d
929            c_data[d] = numpy.sum([-k*x_data[k] * s_data[d-k] for k in range(1,d+1)], axis = 0)/d
930
931        return s_data, c_data
932
933    @classmethod
934    def _pb_sincos(cls, sbar_data, cbar_data, x_data, s_data, c_data, out = None):
935        if out is None:
936            raise NotImplementedError('should implement that')
937
938        xbar_data = out
939        cls._amul(sbar_data, c_data, xbar_data)
940        cls._amul(cbar_data, -s_data, xbar_data)
941
942    @classmethod
943    def _arcsin(cls, x_data, out = None):
944        if out is None:
945            raise NotImplementedError('should implement that')
946        y_data,z_data = out
947        D,P = x_data.shape[:2]
948
949        # base point: d = 0
950        y_data[0] = numpy.arcsin(x_data[0])
951        z_data[0] = numpy.cos(y_data[0])
952
953        # higher order coefficients: d > 0
954        for d in range(1,D):
955            y_data[d] = (d*x_data[d] - numpy.sum([k*y_data[k] * z_data[d-k] for k in range(1,d)], axis = 0))/(z_data[0]*d)
956            z_data[d] = -numpy.sum([k*y_data[k] * x_data[d-k] for k in range(1,d+1)], axis = 0)/d
957
958        return y_data, z_data
959
960    @classmethod
961    def _arccos(cls, x_data, out = None):
962        if out is None:
963            raise NotImplementedError('should implement that')
964        y_data,z_data = out
965        D,P = x_data.shape[:2]
966
967        # base point: d = 0
968        y_data[0] = numpy.arccos(x_data[0])
969        z_data[0] = -numpy.sin(y_data[0])
970
971        # higher order coefficients: d > 0
972        for d in range(1,D):
973            y_data[d] = (d*x_data[d] - numpy.sum([k*y_data[k] * z_data[d-k] for k in range(1,d)], axis = 0))/(z_data[0]*d)
974            z_data[d] = -numpy.sum([k*y_data[k] * x_data[d-k] for k in range(1,d+1)], axis = 0)/d
975
976        return y_data, z_data
977
978    @classmethod
979    def _arctan(cls, x_data, out = None):
980        if out is None:
981            raise NotImplementedError('should implement that')
982        y_data,z_data = out
983        D,P = x_data.shape[:2]
984
985        # base point: d = 0
986        y_data[0] = numpy.arctan(x_data[0])
987        z_data[0] = 1 + x_data[0] * x_data[0]
988
989        # higher order coefficients: d > 0
990        for d in range(1,D):
991            y_data[d] = (d*x_data[d] - numpy.sum([k*y_data[k] * z_data[d-k] for k in range(1,d)], axis = 0))/(z_data[0]*d)
992            z_data[d] = 2* numpy.sum([k*x_data[k] * x_data[d-k] for k in range(1,d+1)], axis = 0)/d
993
994        return y_data, z_data
995
996
997
998    @classmethod
999    def _sinhcosh(cls, x_data, out = None):
1000        if out is None:
1001            raise NotImplementedError('should implement that')
1002        s_data,c_data = out
1003        D,P = x_data.shape[:2]
1004
1005        # base point: d = 0
1006        s_data[0] = numpy.sinh(x_data[0])
1007        c_data[0] = numpy.cosh(x_data[0])
1008
1009        # higher order coefficients: d > 0
1010        for d in range(1,D):
1011            s_data[d] = (numpy.sum([k*x_data[k] * c_data[d-k] for k in range(1,d+1)], axis = 0))/d
1012            c_data[d] = (numpy.sum([k*x_data[k] * s_data[d-k] for k in range(1,d+1)], axis = 0))/d
1013
1014        return s_data, c_data
1015
1016    @classmethod
1017    def _tanhsech2(cls, x_data, out = None):
1018        if out is None:
1019            raise NotImplementedError('should implement that')
1020        y_data,z_data = out
1021        D,P = x_data.shape[:2]
1022
1023        # base point: d = 0
1024        y_data[0] = numpy.tanh(x_data[0])
1025        z_data[0] = 1-y_data[0]*y_data[0]
1026
1027        # higher order coefficients: d > 0
1028        for d in range(1,D):
1029            y_data[d] = (numpy.sum([k*x_data[k] * z_data[d-k] for k in range(1,d+1)], axis = 0))/d
1030            z_data[d] = -2*(numpy.sum([k*y_data[k] * y_data[d-k] for k in range(1,d+1)], axis = 0))/d
1031
1032        return y_data, z_data
1033
1034    @classmethod
1035    def _erf(cls, x_data, out=None):
1036        fprime_data = (2. / math.sqrt(math.pi)) * cls._exp(-cls._square(x_data))
1037        return _black_f_white_fprime(
1038                nthderiv.erf, fprime_data, x_data, out=out)
1039
1040    @classmethod
1041    def _pb_erf(cls, ybar_data, x_data, y_data, out = None):
1042        if out is None:
1043            raise NotImplementedError('should implement that')
1044        fprime_data = (2. / math.sqrt(math.pi)) * cls._exp(-cls._square(x_data))
1045        cls._amul(ybar_data, fprime_data, out=out)
1046
1047    @classmethod
1048    def _erfi(cls, x_data, out=None):
1049        fprime_data = (2. / math.sqrt(math.pi)) * cls._exp(cls._square(x_data))
1050        return _black_f_white_fprime(
1051                nthderiv.erfi, fprime_data, x_data, out=out)
1052
1053    @classmethod
1054    def _pb_erfi(cls, ybar_data, x_data, y_data, out = None):
1055        if out is None:
1056            raise NotImplementedError('should implement that')
1057        fprime_data = (2. / math.sqrt(math.pi)) * cls._exp(cls._square(x_data))
1058        cls._amul(ybar_data, fprime_data, out=out)
1059
1060    @classmethod
1061    def _dpm_hyp1f1(cls, a, b, x_data, out=None):
1062        f = functools.partial(nthderiv.mpmath_hyp1f1, a, b)
1063        return _eval_slow_generic(f, x_data, out=out)
1064
1065    @classmethod
1066    def _pb_dpm_hyp1f1(cls, ybar_data, a, b, x_data, y_data, out=None):
1067        if out is None:
1068            raise NotImplementedError('should implement that')
1069        tmp = cls._dpm_hyp1f1(a+1., b+1., x_data) * (float(a) / float(b))
1070        cls._amul(ybar_data, tmp, out=out)
1071
1072    @classmethod
1073    def _hyp1f1(cls, a, b, x_data, out=None):
1074        f = functools.partial(nthderiv.hyp1f1, a, b)
1075        return _eval_slow_generic(f, x_data, out=out)
1076
1077    @classmethod
1078    def _pb_hyp1f1(cls, ybar_data, a, b, x_data, y_data, out=None):
1079        if out is None:
1080            raise NotImplementedError('should implement that')
1081        tmp = cls._hyp1f1(a+1., b+1., x_data) * (float(a) / float(b))
1082        cls._amul(ybar_data, tmp, out=out)
1083
1084    @classmethod
1085    def _hyperu(cls, a, b, x_data, out=None):
1086        f = functools.partial(nthderiv.hyperu, a, b)
1087        return _eval_slow_generic(f, x_data, out=out)
1088
1089    @classmethod
1090    def _pb_hyperu(cls, ybar_data, a, b, x_data, y_data, out=None):
1091        if out is None:
1092            raise NotImplementedError('should implement that')
1093        tmp = cls._hyperu(a+1., b+1., x_data) * (-a)
1094        cls._amul(ybar_data, tmp, out=out)
1095
1096    @classmethod
1097    def _dpm_hyp2f0(cls, a1, a2, x_data, out=None):
1098        f = functools.partial(nthderiv.mpmath_hyp2f0, a1, a2)
1099        return _eval_slow_generic(f, x_data, out=out)
1100
1101    @classmethod
1102    def _pb_dpm_hyp2f0(cls, ybar_data, a1, a2, x_data, y_data, out=None):
1103        if out is None:
1104            raise NotImplementedError('should implement that')
1105        tmp = cls._dpm_hyp2f0(a1+1., a2+1., x_data) * float(a1) * float(a2)
1106        cls._amul(ybar_data, tmp, out=out)
1107
1108    @classmethod
1109    def _hyp2f0(cls, a1, a2, x_data, out=None):
1110        f = functools.partial(nthderiv.hyp2f0, a1, a2)
1111        return _eval_slow_generic(f, x_data, out=out)
1112
1113    @classmethod
1114    def _pb_hyp2f0(cls, ybar_data, a1, a2, x_data, y_data, out=None):
1115        if out is None:
1116            raise NotImplementedError('should implement that')
1117        tmp = cls._hyp2f0(a1+1., a2+1., x_data) * float(a1) * float(a2)
1118        cls._amul(ybar_data, tmp, out=out)
1119
1120    @classmethod
1121    def _hyp0f1(cls, b, x_data, out=None):
1122        f = functools.partial(nthderiv.hyp0f1, b)
1123        return _eval_slow_generic(f, x_data, out=out)
1124
1125    @classmethod
1126    def _pb_hyp0f1(cls, ybar_data, b, x_data, y_data, out=None):
1127        if out is None:
1128            raise NotImplementedError('should implement that')
1129        tmp = cls._hyp0f1(b+1., x_data) / float(b)
1130        cls._amul(ybar_data, tmp, out=out)
1131
1132    @classmethod
1133    def _polygamma(cls, m, x_data, out=None):
1134        f = functools.partial(nthderiv.polygamma, m)
1135        return _eval_slow_generic(f, x_data, out=out)
1136
1137    @classmethod
1138    def _pb_polygamma(cls, ybar_data, m, x_data, y_data, out=None):
1139        if out is None:
1140            raise NotImplementedError('should implement that')
1141        tmp = cls._polygamma(m+1, x_data)
1142        cls._amul(ybar_data, tmp, out=out)
1143
1144    @classmethod
1145    def _psi(cls, x_data, out=None):
1146        if out is None:
1147            raise NotImplementedError('should implement that')
1148        return _eval_slow_generic(nthderiv.psi, x_data, out=out)
1149
1150    @classmethod
1151    def _pb_psi(cls, ybar_data, x_data, y_data, out=None):
1152        if out is None:
1153            raise NotImplementedError('should implement that')
1154        tmp = cls._polygamma(1, x_data)
1155        cls._amul(ybar_data, tmp, out=out)
1156
1157    @classmethod
1158    def _gammaln(cls, x_data, out=None):
1159        if out is None:
1160            raise NotImplementedError('should implement that')
1161        return _eval_slow_generic(nthderiv.gammaln, x_data, out=out)
1162
1163    @classmethod
1164    def _pb_gammaln(cls, ybar_data, x_data, y_data, out=None):
1165        if out is None:
1166            raise NotImplementedError('should implement that')
1167        tmp = cls._polygamma(0, x_data)
1168        cls._amul(ybar_data, tmp, out=out)
1169
1170
1171    @classmethod
1172    def _dot(cls, x_data, y_data, out = None):
1173        """
1174        z = dot(x,y)
1175        """
1176
1177        if out is None:
1178            new_shp = x_data.shape[:-1] + y_data.shape[2:-2] + (y_data.shape[-1],)
1179            out = numpy.zeros(new_shp, dtype=numpy.promote_types(x_data.dtype, y_data.dtype) )
1180
1181        z_data = out
1182        z_data[...] = 0.
1183
1184        D,P = x_data.shape[:2]
1185
1186        # print 'x_data.shape=', x_data.shape
1187        # print 'y_data.shape=', y_data.shape
1188        # print 'z_data.shape=', z_data.shape
1189
1190        for d in range(D):
1191            for p in range(P):
1192                for c in range(d+1):
1193                    tmp = numpy.dot(x_data[c,p,...],
1194                                    y_data[d-c,p,...])
1195                    numpy.add(z_data[d,p,...], tmp, out=z_data[d,p, ...], casting='unsafe')
1196
1197        return out
1198
1199    @classmethod
1200    def _dot_pullback(cls, zbar_data, x_data, y_data, z_data, out = None):
1201        if out is None:
1202            raise NotImplementedError('should implement that')
1203
1204        (xbar_data, ybar_data) = out
1205
1206        xbar_data += cls._dot(zbar_data, cls._transpose(y_data), out = xbar_data.copy())
1207        ybar_data += cls._dot(cls._transpose(x_data), zbar_data, out = ybar_data.copy())
1208
1209        return out
1210
1211    @classmethod
1212    def _dot_non_UTPM_y(cls, x_data, y_data, out = None):
1213        """
1214        z = dot(x,y)
1215        """
1216
1217        if out is None:
1218            raise NotImplementedError('should implement that')
1219
1220        z_data = out
1221        z_data[...] = 0.
1222
1223        D,P = x_data.shape[:2]
1224
1225        # print 'z_data=',z_data
1226
1227        for d in range(D):
1228            for p in range(P):
1229                z_data[d,p,...] = numpy.dot(x_data[d,p,...], y_data[...])
1230
1231        return out
1232
1233    @classmethod
1234    def _dot_non_UTPM_x(cls, x_data, y_data, out = None):
1235        """
1236        z = dot(x,y)
1237        """
1238
1239        if out is None:
1240            raise NotImplementedError('should implement that')
1241
1242        z_data = out
1243        z_data[...] = 0.
1244
1245        D,P = y_data.shape[:2]
1246
1247        for d in range(D):
1248            for p in range(P):
1249                z_data[d,p,...] = numpy.dot(x_data[...], y_data[d,p,...])
1250
1251        return out
1252
1253    @classmethod
1254    def _outer(cls, x_data, y_data, out = None):
1255        """
1256        z = outer(x,y)
1257        """
1258
1259        if out is None:
1260            raise NotImplementedError('should implement that')
1261
1262        z_data = out
1263        z_data[...] = 0.
1264
1265        D,P = x_data.shape[:2]
1266
1267        for d in range(D):
1268            for p in range(P):
1269                for c in range(d+1):
1270                    z_data[d,p,...] += numpy.outer(x_data[c,p,...], y_data[d-c,p,...])
1271
1272        return out
1273
1274    @classmethod
1275    def _outer_non_utpm_y(cls, x_data, y, out = None):
1276        """
1277        z = outer(x,y)
1278        where x is UTPM and y is ndarray
1279        """
1280
1281        if out is None:
1282            raise NotImplementedError('should implement that')
1283
1284        z_data = out
1285        z_data[...] = 0.
1286
1287        D,P = x_data.shape[:2]
1288
1289        for d in range(D):
1290            for p in range(P):
1291                z_data[d,p,...] += numpy.outer(x_data[d,p,...], y)
1292
1293        return out
1294
1295
1296    @classmethod
1297    def _outer_non_utpm_x(cls, x, y_data, out = None):
1298        """
1299        z = outer(x,y)
1300        where y is UTPM and x is ndarray
1301        """
1302
1303        if out is None:
1304            raise NotImplementedError('should implement that')
1305
1306        z_data = out
1307        z_data[...] = 0.
1308
1309        D,P = y_data.shape[:2]
1310
1311        for d in range(D):
1312            for p in range(P):
1313                z_data[d,p,...] += numpy.outer(x, y_data[d,p,...])
1314
1315        return out
1316
1317
1318
1319    @classmethod
1320    def _outer_pullback(cls, zbar_data, x_data, y_data, z_data, out = None):
1321        if out is None:
1322            raise NotImplementedError('should implement that')
1323
1324        (xbar_data, ybar_data) = out
1325
1326        xbar_data += cls._dot(zbar_data, y_data, out = xbar_data.copy())
1327        ybar_data += cls._dot(zbar_data, x_data, out = ybar_data.copy())
1328
1329        return out
1330
1331    @classmethod
1332    def _inv(cls, x_data, out = None):
1333        """
1334        computes y = inv(x)
1335        """
1336
1337        if out is None:
1338            raise NotImplementedError('should implement that')
1339
1340        y_data, = out
1341        (D,P,N,M) = y_data.shape
1342
1343        # tc[0] element
1344        for p in range(P):
1345            y_data[0,p,:,:] = numpy.linalg.inv(x_data[0,p,:,:])
1346
1347        # tc[d] elements
1348        for d in range(1,D):
1349            for p in range(P):
1350                for c in range(1,d+1):
1351                    y_data[d,p,:,:] += numpy.dot(x_data[c,p,:,:], y_data[d-c,p,:,:],)
1352                y_data[d,p,:,:] =  numpy.dot(-y_data[0,p,:,:], y_data[d,p,:,:],)
1353        return y_data
1354
1355
1356    @classmethod
1357    def _inv_pullback(cls, ybar_data, x_data, y_data, out = None):
1358        if out is None:
1359            raise NotImplementedError('should implement that')
1360
1361        xbar_data = out
1362        tmp1 = numpy.zeros(xbar_data.shape)
1363        tmp2 = numpy.zeros(xbar_data.shape)
1364
1365        tmp1 = cls._dot(ybar_data, cls._transpose(y_data), out = tmp1)
1366        tmp2 = cls._dot(cls._transpose(y_data), tmp1, out = tmp2)
1367
1368        xbar_data -= tmp2
1369        return out
1370
1371
1372    @classmethod
1373    def _solve_pullback(cls, ybar_data, A_data, x_data, y_data, out = None):
1374
1375        if out is None:
1376            raise NotImplementedError('should implement that')
1377
1378        Abar_data = out[0]
1379        xbar_data = out[1]
1380
1381        Tbar = numpy.zeros(xbar_data.shape)
1382
1383        cls._solve( A_data.transpose((0,1,3,2)), ybar_data, out = Tbar)
1384        Tbar *= -1.
1385        cls._iouter(Tbar, y_data, Abar_data)
1386        xbar_data -= Tbar
1387
1388        return out
1389
1390    @classmethod
1391    def _solve_non_UTPM_x_pullback(cls, ybar_data, A_data, x_data, y_data, out = None):
1392
1393        if out is None:
1394            raise NotImplementedError('should implement that')
1395
1396        Abar_data = out
1397
1398        Tbar = numpy.zeros(xbar_data.shape)
1399
1400        cls._solve( A_data.transpose((0,1,3,2)), ybar_data, out = Tbar)
1401        Tbar *= -1.
1402        cls._iouter(Tbar, y_data, Abar_data)
1403
1404        return out, None
1405
1406
1407    @classmethod
1408    def _solve(cls, A_data, x_data, out = None):
1409        """
1410        solves the linear system of equations for y::
1411
1412            A y = x
1413
1414        """
1415
1416        if out is None:
1417            raise NotImplementedError('should implement that')
1418
1419        y_data = out
1420
1421        x_shp = x_data.shape
1422        A_shp = A_data.shape
1423        D,P,M,N = A_shp
1424
1425        D,P,M,K = x_shp
1426
1427        # d = 0:  base point
1428        for p in range(P):
1429            y_data[0,p,...] = numpy.linalg.solve(A_data[0,p,...], x_data[0,p,...])
1430
1431        # d = 1,...,D-1
1432        dtype = numpy.promote_types(A_data.dtype, x_data.dtype)
1433        tmp = numpy.zeros((M,K),dtype=dtype)
1434        for d in range(1, D):
1435            for p in range(P):
1436                tmp[:,:] = x_data[d,p,:,:]
1437                for k in range(1,d+1):
1438                    tmp[:,:] -= numpy.dot(A_data[k,p,:,:],y_data[d-k,p,:,:])
1439                y_data[d,p,:,:] = numpy.linalg.solve(A_data[0,p,:,:],tmp)
1440
1441        return out
1442
1443
1444    @classmethod
1445    def _solve_non_UTPM_A(cls, A_data, x_data, out = None):
1446        """
1447        solves the linear system of equations for y::
1448
1449            A y = x
1450
1451        when A is a simple (N,N) float array
1452        """
1453
1454        if out is None:
1455            raise NotImplementedError('should implement that')
1456
1457        y_data = out
1458
1459        x_shp = numpy.shape(x_data)
1460        A_shp = numpy.shape(A_data)
1461        M,N = A_shp
1462        D,P,M,K = x_shp
1463
1464        assert M == N
1465
1466        for d in range(D):
1467            for p in range(P):
1468                y_data[d,p,...] = numpy.linalg.solve(A_data[:,:], x_data[d,p,...])
1469
1470        return out
1471
1472    @classmethod
1473    def _solve_non_UTPM_x(cls, A_data, x_data, out = None):
1474        """
1475        solves the linear system of equations for y::
1476
1477            A y = x
1478
1479        where x is simple (N,K) float array
1480        """
1481
1482        if out is None:
1483            raise NotImplementedError('should implement that')
1484
1485        y_data = out
1486
1487        x_shp = numpy.shape(x_data)
1488        A_shp = numpy.shape(A_data)
1489        D,P,M,N = A_shp
1490        M,K = x_shp
1491
1492        assert M==N
1493
1494        # d = 0:  base point
1495        for p in range(P):
1496            y_data[0,p,...] = numpy.linalg.solve(A_data[0,p,...], x_data[...])
1497
1498        # d = 1,...,D-1
1499        tmp = numpy.zeros((M,K),dtype=float)
1500        for d in range(1, D):
1501            for p in range(P):
1502                tmp[:,:] = 0.
1503                for k in range(1,d+1):
1504                    tmp[:,:] -= numpy.dot(A_data[k,p,:,:],y_data[d-k,p,:,:])
1505                y_data[d,p,:,:] = numpy.linalg.solve(A_data[0,p,:,:],tmp)
1506
1507
1508        return out
1509
1510    @classmethod
1511    def _cholesky(cls, A_data, L_data):
1512        """
1513        compute the choleksy decomposition in Taylor arithmetic of a symmetric
1514        positive definite matrix A
1515        i.e.
1516        ..math:
1517
1518            A = L L^T
1519        """
1520        DT,P,N = numpy.shape(A_data)[:3]
1521
1522        # allocate (temporary) projection matrix
1523        Proj = numpy.zeros((N,N))
1524        for r in range(N):
1525            for c in range(r+1):
1526                if r == c:
1527                    Proj[r,c] = 0.5
1528                else:
1529                    Proj[r,c] = 1
1530
1531        for p in range(P):
1532
1533            # base point: d = 0
1534            L_data[0,p] = numpy.linalg.cholesky(A_data[0,p])
1535
1536            # allocate temporary storage
1537            L0inv = numpy.linalg.inv(L_data[0,p])
1538            dF    = numpy.zeros((N,N),dtype=float)
1539
1540            # higher order coefficients: d > 0
1541            # STEP 1: compute diagonal elements of dL
1542            for D in range(1,DT):
1543                dF *= 0
1544                for d in range(1,D):
1545                    dF += numpy.dot(L_data[D-d,p], L_data[d,p].T)
1546
1547                # print numpy.dot(L_data[1,p],L_data[1,p].T)
1548                # print 'dF = ',dF
1549
1550                dF -= A_data[D,p]
1551
1552                dF = numpy.dot(numpy.dot(L0inv,dF),L0inv.T)
1553
1554                # compute off-diagonal entries
1555                L_data[D,p] = - numpy.dot( L_data[0,p], Proj * dF)
1556
1557                # compute diagonal entries
1558                tmp1 = numpy.diag(L_data[0,p])
1559                tmp2 = numpy.diag(dF)
1560                tmp3 = -0.5 * tmp1 * tmp2
1561                for n in range(N):
1562                    L_data[D,p,n,n] = tmp3[n]
1563
1564
1565    @classmethod
1566    def build_PL(cls, N):
1567        """
1568        build lower triangular matrix with all ones, i.e.
1569
1570        PL = [[0,0,0],
1571              [1,0,0],
1572              [1,1,0]]
1573        """
1574        return numpy.tril(numpy.ones((N,N)), -1)
1575
1576    @classmethod
1577    def build_PU(cls, N):
1578        """
1579        build upper triangular matrix with all ones, i.e.
1580
1581        PL = [[0,1,1],
1582              [0,0,1],
1583              [0,0,0]]
1584        """
1585        return numpy.triu(numpy.ones((N,N)), 1)
1586
1587
1588    @classmethod
1589    def _pb_cholesky(cls, Lbar_data, A_data, L_data, out = None):
1590        """
1591        pullback of the linear form of the cholesky decomposition
1592        """
1593
1594        if out is None:
1595            raise NotImplementedError('should implement this')
1596
1597        Abar_data = out
1598
1599        D,P,N = A_data.shape[:3]
1600
1601        # compute (P_L + 0.5*P_D) * dot(L.T, Lbar)
1602        Proj = cls.build_PL(N) + 0.5 * numpy.eye(N)
1603        tmp = cls._dot(cls._transpose(L_data), Lbar_data, cls.__zeros_like__(A_data))
1604        tmp *= Proj
1605
1606        # symmetrize (P_L + 0.5*P_D) * dot(L.T, Lbar)
1607        tmp = 0.5*(cls._transpose(tmp) + tmp)
1608
1609        # compute Abar
1610        Linv_data = cls._inv(L_data, (cls.__zeros_like__(A_data),))
1611        tmp2 = cls._dot(cls._transpose(Linv_data), tmp, cls.__zeros_like__(A_data))
1612        tmp3 = cls._dot(tmp2, Linv_data, cls.__zeros_like__(A_data))
1613        Abar_data += tmp3
1614
1615        return Abar_data
1616
1617
1618    @classmethod
1619    def _ndim(cls, a_data):
1620        return a_data[0,0].ndim
1621
1622    @classmethod
1623    def _shape(cls, a_data):
1624        return a_data[0,0].shape
1625
1626    @classmethod
1627    def _reshape(cls, a_data, newshape, order = 'C'):
1628
1629        if order != 'C':
1630            raise NotImplementedError('should implement that')
1631
1632        if isinstance(newshape,int):
1633            newshape = (newshape,)
1634
1635        return numpy.reshape(a_data, a_data.shape[:2] + newshape)
1636
1637    @classmethod
1638    def _pb_reshape(cls, ybar_data, x_data, y_data,  out=None):
1639        if out is None:
1640            raise NotImplementedError('should implement that')
1641
1642        return numpy.reshape(out, x_data.shape)
1643
1644    @classmethod
1645    def _iouter(cls, x_data, y_data, out_data):
1646        """
1647        computes dyadic product and adds it to out
1648        out += x y^T
1649        """
1650
1651        if len(cls._shape(x_data)) == 1:
1652            x_data = cls._reshape(x_data, cls._shape(x_data) + (1,))
1653
1654        if len(cls._shape(y_data)) == 1:
1655            y_data = cls._reshape(y_data, cls._shape(y_data) + (1,))
1656
1657        tmp = cls.__zeros__(out_data.shape, dtype = out_data.dtype)
1658        cls._dot(x_data, cls._transpose(y_data), out = tmp)
1659
1660        out_data += tmp
1661
1662        return out_data
1663
1664
1665
1666    @classmethod
1667    def __zeros_like__(cls, data):
1668        return numpy.zeros_like(data)
1669
1670    @classmethod
1671    def __zeros__(cls, shp, dtype):
1672        return numpy.zeros(shp, dtype = dtype)
1673
1674    @classmethod
1675    def _qr(cls,  A_data, out = None,  work = None, epsilon = 1e-14):
1676        """
1677        computes the qr decomposition (Q,R) = qr(A)    <===>    QR = A
1678
1679        INPUTS:
1680            A_data      (D,P,M,N) array             regular matrix
1681
1682        OUTPUTS:
1683            Q_data      (D,P,M,K) array             orthogonal vectors Q_1,...,Q_K
1684            R_data      (D,P,K,N) array             upper triagonal matrix
1685
1686            where K = min(M,N)
1687
1688        """
1689
1690        # check if the output array is provided
1691        if out is None:
1692            raise NotImplementedError('need to implement that...')
1693        Q_data = out[0]
1694        R_data = out[1]
1695
1696        DT,P,M,N = numpy.shape(A_data)
1697        K = min(M,N)
1698
1699        if M < N:
1700            A1_data = A_data[:,:,:,:M]
1701            A2_data = A_data[:,:,:,M:]
1702            R1_data = R_data[:,:,:,:M]
1703            R2_data = R_data[:,:,:,M:]
1704
1705            cls._qr_rectangular(A1_data, out = (Q_data, R1_data), epsilon = epsilon)
1706            # print 'QR1 - A1 = ', cls._dot(Q_data, R1_data, numpy.zeros_like(A_data[:,:,:,:M])) - A_data[:,:,:,:M]
1707            # print 'R2_data=',R2_data
1708            cls._dot(cls._transpose(Q_data), A2_data, out=R2_data)
1709            # print 'R2_data=',R2_data
1710
1711
1712        else:
1713            cls._qr_rectangular(A_data, out = (Q_data, R_data))
1714
1715    @classmethod
1716    def _qr_rectangular(cls,  A_data, out = None,  work = None, epsilon = 1e-14):
1717        """
1718        computation of qr(A) where A.shape(M,N) with M >= N
1719
1720        this function is called by the more general function _qr
1721        """
1722
1723
1724        DT,P,M,N = numpy.shape(A_data)
1725        K = min(M,N)
1726
1727        # check if the output array is provided
1728        if out is None:
1729            raise NotImplementedError('need to implement that...')
1730        Q_data = out[0]
1731        R_data = out[1]
1732
1733        # input checks
1734        if Q_data.shape != (DT,P,M,K):
1735            raise ValueError('expected Q_data.shape = %s but provided %s'%(str((DT,P,M,K)),str(Q_data.shape)))
1736        assert R_data.shape == (DT,P,K,N)
1737
1738        if not M >= N:
1739            raise NotImplementedError('A_data.shape = (DT,P,M,N) = %s but require (for now) that M>=N')
1740
1741
1742        # check if work arrays are provided, if not allocate them
1743        if work is None:
1744            dF = numpy.zeros((P,M,N))
1745            dG = numpy.zeros((P,K,K))
1746            X  = numpy.zeros((P,K,K))
1747            PL = numpy.array([[ r > c for c in range(N)] for r in range(K)],dtype=float)
1748            Rinv = numpy.zeros((P,K,N))
1749
1750        else:
1751            raise NotImplementedError('need to implement that...')
1752
1753
1754        # INIT: compute the base point
1755        for p in range(P):
1756            Q_data[0,p,:,:], R_data[0,p,:,:] = numpy.linalg.qr(A_data[0,p,:,:])
1757
1758
1759        for p in range(P):
1760            rank = 0
1761            for n in range(N):
1762                if abs(R_data[0,p,n,n]) > epsilon:
1763                    rank += 1
1764
1765            Rinv[p] = 0.
1766            if rank != 0:
1767                Rinv[p,:rank,:rank] = numpy.linalg.inv(R_data[0,p,:rank,:rank])
1768
1769        # ITERATE: compute the derivatives
1770        for D in range(1,DT):
1771            # STEP 1:
1772            dF[...] = 0.
1773            dG[...] = 0
1774            X[...]  = 0
1775
1776            for d in range(1,D):
1777                for p in range(P):
1778                    dF[p] += numpy.dot(Q_data[d,p,:,:], R_data[D-d,p,:,:])
1779                    dG[p] -= numpy.dot(Q_data[d,p,:,:].T, Q_data[D-d,p,:,:])
1780
1781            # STEP 2:
1782            H = A_data[D,:,:,:] - dF[:,:,:]
1783            S =  0.5 * dG
1784
1785            # STEP 3:
1786            for p in range(P):
1787                X[p,:,:] = PL * (numpy.dot( numpy.dot(Q_data[0,p,:,:].T, H[p,:,:,]), Rinv[p]) - S[p,:,:])
1788                X[p,:,:] = X[p,:,:] - X[p,:,:].T
1789
1790            # STEP 4:
1791            K = S + X
1792
1793            # STEP 5:
1794            for p in range(P):
1795                R_data[D,p,:,:] = numpy.dot(Q_data[0,p,:,:].T, H[p,:,:]) - numpy.dot(K[p,:,:],R_data[0,p,:,:])
1796
1797            # STEP 6:
1798            for p in range(P):
1799                if M == N:
1800                    Q_data[D,p,:,:] = numpy.dot(Q_data[0,p,:,:],K[p,:,:])
1801                else:
1802                    Q_data[D,p,:,:] = numpy.dot(H[p] - numpy.dot(Q_data[0,p],R_data[D,p]), Rinv[p])
1803
1804
1805    @classmethod
1806    def _qr_full(cls,  A_data, out = None,  work = None):
1807        """
1808        computation of QR = A
1809
1810        INPUTS:
1811            A    (M,N) UTPM instance            with A.data[0,:] have all rank N, M >= N
1812
1813        OUTPUTS:
1814            Q     (M,M) UTPM instance             orthonormal matrix
1815            R     (M,N) UTPM instance             only upper M rows are non-zero, i.e. R[:N,:] == 0
1816
1817
1818        """
1819
1820
1821
1822        D,P,M,N = numpy.shape(A_data)
1823
1824        # check if the output array is provided
1825        if out is None:
1826            raise NotImplementedError('need to implement that...')
1827        Q_data = out[0]
1828        R_data = out[1]
1829
1830        # input checks
1831        if Q_data.shape != (D,P,M,M):
1832            raise ValueError('expected Q_data.shape = %s but provided %s'%(str((DT,P,M,K)),str(Q_data.shape)))
1833        assert R_data.shape == (D,P,M,N)
1834
1835        if not M >= N:
1836            raise NotImplementedError('A_data.shape = (DT,P,M,N) = %s but require (for now) that M>=N')
1837
1838        # check if work arrays are provided, if not allocate them
1839        if work is None:
1840            dF = numpy.zeros((M,N))
1841            S = numpy.zeros((M,M))
1842            X  = numpy.zeros((M,M))
1843            PL = numpy.array([[ r > c for c in range(M)] for r in range(M)],dtype=float)
1844            Rinv = numpy.zeros((N,N))
1845            K  = numpy.zeros((M,M))
1846
1847        else:
1848            raise NotImplementedError('need to implement that...')
1849
1850        for p in range(P):
1851
1852            # d = 0: compute the base point
1853            Q_data[0,p,:,:], R_data[0,p,:,:] = scipy.linalg.qr(A_data[0,p,:,:])
1854
1855            # d > 0: iterate
1856            Rinv[:,:] = numpy.linalg.inv(R_data[0,p,:N,:])
1857
1858            for d in range(1,D):
1859                # STEP 1: compute dF and S
1860                dF[...] = A_data[d,p,:,:]
1861                S[...] = 0
1862
1863                for k in range(1,d):
1864                    dF[...] -= numpy.dot(Q_data[d-k,p,:,:], R_data[k,p,:,:])
1865                    S[...]  -= numpy.dot(Q_data[d-k,p,:,:].T, Q_data[k,p,:,:])
1866                S *= 0.5
1867
1868                # STEP 2: compute X
1869                X[...] = 0
1870                X[:,:N] = PL[:,:N] * (numpy.dot( numpy.dot(Q_data[0,p,:,:].T, dF[:,:]), Rinv) - S[:,:N])
1871                X[:,:] = X[:,:] - X[:,:].T
1872                K[...] = 0; K[...] += S;  K[...] += X
1873                R_data[d,p,:,:] = numpy.dot(Q_data[0,p,:,:].T, dF) - numpy.dot(K,R_data[0,p,:,:])
1874                Q_data[d,p,:,:] = numpy.dot(Q_data[0,p,:,:],K)
1875
1876
1877
1878    @classmethod
1879    def _qr_full_pullback(cls, Qbar_data, Rbar_data, A_data, Q_data, R_data, out = None):
1880        """
1881        computes the pullback of the qr decomposition (Q,R) = qr(A)    <===>    QR = A
1882
1883            A_data      (D,P,M,N) array             regular matrix
1884            Q_data      (D,P,M,M) array             orthogonal vectors Q_1,...,Q_K
1885            R_data      (D,P,M,N) array             upper triagonal matrix
1886
1887        """
1888
1889
1890        if out is None:
1891            raise NotImplementedError('need to implement that...')
1892
1893        Abar_data = out
1894        A_shp = A_data.shape
1895        D,P,M,N = A_shp
1896
1897        if M < N:
1898            raise NotImplementedError('supplied matrix has more columns that rows')
1899
1900        # STEP 1: compute: tmp1 = PL * ( Q.T Qbar - Qbar.T Q + R Rbar.T - Rbar R.T)
1901        PL = numpy.array([[ r > c for c in range(M)] for r in range(M)],dtype=float)
1902        tmp = cls._dot(cls._transpose(Q_data), Qbar_data) + cls._dot(R_data, cls._transpose(Rbar_data))
1903        tmp = tmp - cls._transpose(tmp)
1904
1905        for d in range(D):
1906            for p in range(P):
1907                tmp[d,p] *= PL
1908
1909        # STEP 2: compute H = K * R1^{-T}
1910        R1 = R_data[:,:,:N,:]
1911        K = tmp[:,:,:,:N]
1912        H = numpy.zeros((D,P,M,N))
1913
1914        cls._solve(R1, cls._transpose(K), out = cls._transpose(H))
1915
1916        H += Rbar_data
1917
1918        Abar_data += cls._dot(Q_data, H, out = numpy.zeros_like(Abar_data))
1919
1920        # tmp2 = cls._solve(cls._transpose(R_data[:,:,:N,:]), cls._transpose(tmp), out = numpy.zeros((D,P,M,N)))
1921        # tmp = cls._dot(tmp[:,:,:,:N], cls._transpose
1922        # print Rbar_data.shape
1923
1924
1925
1926
1927    @classmethod
1928    def _eigh(cls, L_data, Q_data, A_data, epsilon = 1e-8, full_output = False):
1929        """
1930        computes the eigenvalue decompositon
1931
1932        L,Q = eig(A)
1933
1934        for symmetric matrix A with possibly repeated eigenvalues, i.e.
1935        where L is a diagonal matrix of ordered eigenvalues l_1 >= l_2 >= ...>= l_N
1936        and Q a matrix of corresponding orthogonal eigenvectors
1937
1938        """
1939
1940        def lift_Q(Q, d, D):
1941            """ lift orthonormal matrix from degree d to degree D
1942            given [Q]_d = [Q_0,...,Q_{d-1]] s.t.  0 =_d [Q^T]_d [Q]_d - Id
1943            compute dQ s.t. [Q]_D = [[Q]_d , [dQ]_{D-d] satisfies 0 =_D [Q^T]_D [Q]_D - Id
1944            """
1945            S = numpy.zeros(Q.shape[1:])
1946            S = cls.__zeros_like__(Q[0])
1947            for k in range(d,D):
1948                S *= 0
1949                for i in range(1,k):
1950                    S += numpy.dot(Q[i,:,:].T, Q[k-i,:,:])
1951                Q[k] = -0.5 * numpy.dot(Q[0], S)
1952            return Q
1953
1954        # input checks
1955        DT,P,M,N = numpy.shape(A_data)
1956        assert M == N
1957        if Q_data.shape != (DT,P,N,N):
1958            raise ValueError('expected Q_data.shape = %s but provided %s'%(str((DT,P,M,K)),str(Q_data.shape)))
1959        if L_data.shape != (DT,P,N):
1960            raise ValueError('expected L_data.shape = %s but provided %s'%(str((DT,P,N)),str(L_data.shape)))
1961
1962
1963        for p in range(P):
1964            b = [0,N]
1965            L_tilde_data = A_data[:,p].copy()
1966            Q_data[0,p] = numpy.eye(N)
1967            for D in range(DT):
1968                # print 'relaxed problem of order d=',D+1
1969                # print 'b=',b
1970                tmp_b_list = []
1971                for nb in range(len(b)-1):
1972                    start, stop = b[nb], b[nb+1]
1973
1974                    # print 'stop-start=',stop-start
1975
1976                    Q_hat_data = numpy.zeros((DT-D, stop-start, stop-start), dtype = A_data.dtype)
1977                    L_hat_data = numpy.zeros((DT-D, stop-start, stop-start), dtype = A_data.dtype)
1978
1979
1980                    tmp_b = cls._eigh1(L_hat_data, Q_hat_data, L_tilde_data[D:, start:stop, start:stop], epsilon = epsilon)
1981                    tmp_b_list.append( tmp_b)
1982
1983                    # compute L_tilde
1984                    L_data[D:,p, start:stop] = numpy.diag(L_hat_data[0])
1985                    L_tilde_data[D:, start:stop, start:stop] = L_hat_data
1986
1987                    # update Q
1988                    # print 'Q_hat_data=',Q_hat_data
1989                    Q_tmp = numpy.zeros((DT, stop-start, stop-start), dtype = A_data.dtype)
1990                    Q_tmp[:DT-D] = Q_hat_data
1991
1992                    # print 'D,DT=',D,DT, DT-D
1993                    Q_tmp = lift_Q(Q_tmp, DT-D, DT)
1994
1995                    Q_tmp  = Q_tmp.reshape((DT,1,stop-start,stop-start))
1996                    # print 'Q_tmp=',Q_tmp
1997
1998                    # print Q_tmp.shape
1999                    Q_data[:,p:p+1,:,start:stop] = cls._dot(Q_data[:,p:p+1,:,start:stop],Q_tmp,numpy.zeros_like(Q_data[:,p:p+1,:,start:stop]))
2000
2001                    # print 'Q_data=',Q_data
2002
2003                # print tmp_b_list
2004                offset = 0
2005                for tmp_b in tmp_b_list:
2006                    # print 'tmp_b=',tmp_b + offset
2007                    b = numpy.union1d(b, tmp_b + offset)
2008                    offset += numpy.max(tmp_b)
2009                # print 'b=',b
2010
2011        # print Q_data
2012        # print L_data
2013
2014
2015    @classmethod
2016    def _eigh1(cls, L_data, Q_data, A_data, epsilon = 1e-8, full_output = False):
2017        """
2018        computes the solution of the relaxed problem of order 1
2019
2020        L,Q = eig(A)
2021
2022        for symmetric matrix A with possibly repeated eigenvalues, i.e.
2023        where L[0] is a diagonal matrix of ordered eigenvalues l_1 >= l_2 >= ...>= l_N
2024        and L[1:] is block diagonal.
2025
2026        and Q a matrix of corresponding orthonormal eigenvectors.
2027
2028        """
2029
2030        def find_repeated_values(L):
2031            """
2032            INPUT:  L    (N,) array of ordered values, dtype = float
2033            OUTPUT: b    (Nb,) array s.t. L[b[i:i+1]] are all repeated values
2034
2035            Nb is the number of blocks of repeated values. It holds that
2036            b[-1] = N.
2037
2038            e.g. L = [1.,1.,1.,2.,2.,3.,5.,7.,7.]
2039            then the output is [0,3,5,6,7,9]
2040            """
2041
2042            # print 'finding repeated eigenvalues'
2043            # print 'L = ',L
2044
2045
2046            N = len(L)
2047            # print 'L=',L
2048            b = [0]
2049            n = 0
2050            while n < N:
2051                m = n + 1
2052                while m < N:
2053                    # print 'n,m=',n,m
2054                    tmp = L[n] - L[m]
2055                    if numpy.abs(tmp) > epsilon:
2056                        b += [m]
2057                        break
2058                    m += 1
2059                n += (m - n)
2060            b += [N]
2061
2062            # print 'b=',b
2063            return numpy.asarray(b)
2064
2065        # input checks
2066        DT,M,N = numpy.shape(A_data)
2067        assert M == N
2068        if Q_data.shape != (DT,N,N):
2069            raise ValueError('expected Q_data.shape = %s but provided %s'%(str((DT,N,N)),str(Q_data.shape)))
2070        if L_data.shape != (DT,N,N):
2071            raise ValueError('expected L_data.shape = %s but provided %s'%(str((DT,N,N)),str(L_data.shape)))
2072
2073        # INIT: compute the base point
2074        tmp, Q_data[0,:,:] = numpy.linalg.eigh(A_data[0,:,:])
2075
2076        # set output L_data
2077        for n in range(N):
2078            L_data[0,n,n] = tmp[n]
2079
2080        # find repeated eigenvalues that define the block structure of the higher order coefficients
2081        b = find_repeated_values(tmp)
2082        # print 'b=',b
2083
2084        # compute H = 1/E
2085        H = numpy.zeros((N,N), dtype = A_data.dtype)
2086        for r in range(N):
2087            for c in range(N):
2088                tmp = L_data[0,c,c] - L_data[0,r,r]
2089                if abs(tmp) > epsilon:
2090                    H[r,c] = 1./tmp
2091        dG = numpy.zeros((N,N), dtype = A_data.dtype)
2092
2093        # ITERATE: compute derivatives
2094        for D in range(1,DT):
2095            dG[...] = 0.
2096
2097            # STEP 1:
2098            dF = truncated_triple_dot(Q_data.transpose(0,2,1), A_data, Q_data, D)
2099
2100            for d in range(1,D):
2101                dG += numpy.dot(Q_data[d].T, Q_data[D-d])
2102
2103            # STEP 2:
2104            S = -0.5 * dG
2105
2106            # STEP 3:
2107            K = dF + numpy.dot(numpy.dot(Q_data[0].T, A_data[D]),Q_data[0]) + numpy.dot(S, L_data[0]) + numpy.dot(L_data[0],S)
2108
2109            # STEP 4: compute L
2110            for nb in range(len(b)-1):
2111                start, stop = b[nb], b[nb+1]
2112                L_data[D,start:stop, start:stop] = K[start:stop, start:stop]
2113
2114            # STEP 5: compute Q
2115            XT = K*H
2116            Q_data[D] = numpy.dot(Q_data[0], XT + S)
2117
2118        return b
2119
2120
2121
2122    @classmethod
2123    def _mul_non_UTPM_x(cls, x_data, y_data, out = None):
2124        """
2125        z = x * y
2126        """
2127
2128        if out is None:
2129            raise NotImplementedError('need to implement that...')
2130        z_data = out
2131
2132        D,P = numpy.shape(y_data)[:2]
2133
2134        for d in range(D):
2135            for p in range(P):
2136                z_data[d,p] = x_data * y_data[d,p]
2137
2138    @classmethod
2139    def _eigh_pullback(cls, lambar_data, Qbar_data, A_data, lam_data, Q_data, out = None):
2140
2141        if out is None:
2142            raise NotImplementedError('need to implement that...')
2143
2144        Abar_data = out
2145
2146        A_shp = A_data.shape
2147        D,P,M,N = A_shp
2148
2149        assert M == N
2150
2151        # allocating temporary storage
2152        H = numpy.zeros(A_shp)
2153        tmp1 = numpy.zeros((D,P,N,N), dtype=float)
2154        tmp2 = numpy.zeros((D,P,N,N), dtype=float)
2155
2156        Id = numpy.zeros((D,P))
2157        Id[0,:] = 1
2158
2159        Lam_data    = cls._diag(lam_data)
2160        Lambar_data = cls._diag(lambar_data)
2161
2162        # STEP 1: compute H
2163        for m in range(N):
2164            for n in range(N):
2165                for p in range(P):
2166                    tmp = lam_data[0,p,n] - lam_data[0,p,m]
2167                    if numpy.abs(tmp) > 1e-8:
2168                        for d in range(D):
2169                            H[d,p,m,n] = 1./tmp
2170                # tmp = lam_data[:,:,n] -   lam_data[:,:,m]
2171                # cls._truediv(Id, tmp, out = H[:,:,m,n])
2172
2173        # STEP 2: compute Lbar +  H * Q^T Qbar
2174        cls._dot(cls._transpose(Q_data), Qbar_data, out = tmp1)
2175        tmp1[...] *= H[...]
2176        tmp1[...] += Lambar_data[...]
2177
2178        # STEP 3: compute Q ( Lbar +  H * Q^T Qbar ) Q^T
2179        cls._dot(Q_data, tmp1, out = tmp2)
2180        cls._dot(tmp2, cls._transpose(Q_data), out = tmp1)
2181
2182        Abar_data += tmp1
2183
2184        return out
2185
2186
2187    @classmethod
2188    def _eigh1_pullback(cls, Lambar_data, Qbar_data, A_data, Lam_data, Q_data, b_list, out = None):
2189
2190        if out is None:
2191            raise NotImplementedError('need to implement that...')
2192
2193        Abar_data = out
2194
2195        A_shp = A_data.shape
2196        D,P,M,N = A_shp
2197
2198
2199        E = numpy.zeros((P,N,N))
2200        tmp1 = numpy.zeros((D,P,N,N), dtype=float)
2201        tmp2 = numpy.zeros((D,P,N,N), dtype=float)
2202
2203
2204        for p in range(P):
2205            lam0 = numpy.diag(Lam_data[0,p])
2206
2207            E[p] += lam0;  E[p] = (E[p].T - lam0).T
2208
2209        with numpy.errstate(divide='ignore'):
2210            H = 1./E
2211        for p in range(P):
2212            b = b_list[p]
2213            for nb in range(b.size-1):
2214                H[p, b[nb]:b[nb+1], b[nb]:b[nb+1] ] = 0
2215
2216
2217        # STEP 2: compute Lbar +  H * Q^T Qbar
2218        cls._dot(cls._transpose(Q_data), Qbar_data, out = tmp1)
2219        tmp1[...] *= H[...]
2220        tmp1[...] += Lambar_data[...]
2221
2222        # STEP 3: compute Q ( Lbar +  H * Q^T Qbar ) Q^T
2223        cls._dot(Q_data, tmp1, out = tmp2)
2224        cls._dot(tmp2, cls._transpose(Q_data), out = tmp1)
2225
2226        Abar_data += tmp1
2227
2228        return out
2229
2230
2231
2232
2233
2234
2235    @classmethod
2236    def _qr_pullback(cls, Qbar_data, Rbar_data, A_data, Q_data, R_data, out = None):
2237        """
2238        computes the pullback of the qr decomposition (Q,R) = qr(A)    <===>    QR = A
2239
2240            A_data      (D,P,M,N) array             regular matrix
2241            Q_data      (D,P,M,K) array             orthogonal vectors Q_1,...,Q_K
2242            R_data      (D,P,K,N) array             upper triagonal matrix
2243
2244            where K = min(M,N)
2245
2246        """
2247
2248        # check if the output array is provided
2249        if out is None:
2250            raise NotImplementedError('need to implement that...')
2251        Abar_data = out
2252
2253        DT,P,M,N = numpy.shape(A_data)
2254        K = min(M,N)
2255
2256        if M < N:
2257            A1_data = A_data[:,:,:,:M]
2258            A2_data = A_data[:,:,:,M:]
2259            R1_data = R_data[:,:,:,:M]
2260            R2_data = R_data[:,:,:,M:]
2261
2262            A1bar_data = Abar_data[:,:,:,:M]
2263            A2bar_data = Abar_data[:,:,:,M:]
2264            R1bar_data = Rbar_data[:,:,:,:M]
2265            R2bar_data = Rbar_data[:,:,:,M:]
2266
2267            Qbar_data = Qbar_data.copy()
2268
2269            Qbar_data += cls._dot(A2_data, cls._transpose(R2bar_data), out = numpy.zeros((DT,P,M,M)))
2270            A2bar_data += cls._dot(Q_data, R2bar_data, out = numpy.zeros((DT,P,M,N-M)))
2271            cls._qr_rectangular_pullback(Qbar_data, R1bar_data, A1_data, Q_data, R1_data, out = A1bar_data)
2272
2273        else:
2274            cls._qr_rectangular_pullback( Qbar_data, Rbar_data, A_data, Q_data, R_data, out = out)
2275
2276    @classmethod
2277    def _qr_rectangular_pullback(cls, Qbar_data, Rbar_data, A_data, Q_data, R_data, out = None):
2278        """
2279        assumes that A.shape = M,N with M >= N
2280        """
2281
2282        if out is None:
2283            raise NotImplementedError('need to implement that...')
2284
2285        Abar_data = out
2286
2287        A_shp = A_data.shape
2288        D,P,M,N = A_shp
2289
2290
2291        if M < N:
2292            raise NotImplementedError('supplied matrix has more columns that rows')
2293
2294        # allocate temporary storage and temporary matrices
2295        tmp1 = numpy.zeros((D,P,N,N))
2296        tmp2 = numpy.zeros((D,P,N,N))
2297        tmp3 = numpy.zeros((D,P,M,N))
2298        tmp4 = numpy.zeros((D,P,M,N))
2299        PL  = numpy.array([[ c < r for c in range(N)] for r in range(N)],dtype=float)
2300
2301        # STEP 1: compute V = Qbar^T Q - R Rbar^T
2302        cls._dot( cls._transpose(Qbar_data), Q_data, out = tmp1)
2303        cls._dot( R_data, cls._transpose(Rbar_data), out = tmp2)
2304        tmp1[...] -= tmp2[...]
2305
2306        # STEP 2: compute PL * (V.T - V)
2307        tmp2[...]  = cls._transpose(tmp1)
2308        tmp2[...] -= tmp1[...]
2309
2310        cls._mul_non_UTPM_x(PL, tmp2, out = tmp1)
2311
2312        # STEP 3: compute PL * (V.T - V) R^{-T}
2313
2314        # compute rank of the zero'th coefficient
2315        rank_list = []
2316        for p in range(P):
2317            rank = 0
2318            # print 'p=',p
2319            for n in range(N):
2320                # print 'R_data[0,p,n,n]=',R_data[0,p,n,n]
2321                if abs(R_data[0,p,n,n]) > 1e-16:
2322                    rank += 1
2323            rank_list.append(rank)
2324
2325        # FIXME: assuming the same rank for all zero'th coefficient
2326        # print 'rank = ', rank
2327        # print tmp1
2328        # print 'tmp1[:,:,:rank,:rank]=',tmp1[:,:,:rank,:rank]
2329        tmp2[...] = 0
2330        cls._solve(R_data[:,:,:rank,:rank], cls._transpose(tmp1[:,:,:rank,:rank]), out = tmp2[:,:,:rank,:rank])
2331        tmp2 = tmp2.transpose((0,1,3,2))
2332
2333        # print 'Rbar_data=',Rbar_data[...]
2334
2335        # STEP 4: compute Rbar + PL * (V.T - V) R^{-T}
2336        tmp2[...] += Rbar_data[...]
2337
2338        # tmp2[...,rank:,:] = 0
2339
2340        # STEP 5: compute Q ( Rbar + PL * (V.T - V) R^{-T} )
2341        cls._dot( Q_data, tmp2, out = tmp3)
2342        Abar_data += tmp3
2343
2344        # print 'Abar_data = ', Abar_data
2345
2346        if M > N:
2347            # STEP 6: compute (Qbar - Q Q^T Qbar) R^{-T}
2348            cls._dot( cls._transpose(Q_data), Qbar_data, out = tmp1)
2349            cls._dot( Q_data, tmp1, out = tmp3)
2350            tmp3 *= -1.
2351            tmp3 += Qbar_data
2352            cls._solve(R_data, cls._transpose(tmp3), out = cls._transpose(tmp4))
2353            Abar_data += tmp4
2354
2355        return out
2356
2357    @classmethod
2358    def _transpose(cls, a_data, axes = None):
2359        """Permute the dimensions of UTPM data"""
2360        if axes is not None:
2361            raise NotImplementedError('should implement that')
2362
2363        Nshp = len(a_data.shape)
2364        axes_ids = tuple(range(2,Nshp)[::-1])
2365        return numpy.transpose(a_data,axes=(0,1) + axes_ids)
2366
2367    @classmethod
2368    def _diag(cls, v_data, k = 0, out = None):
2369        """Extract a diagonal or construct  diagonal UTPM data"""
2370
2371        if numpy.ndim(v_data) == 3:
2372            D,P,N = v_data.shape
2373            if out is None:
2374                out = numpy.zeros((D,P,N,N),dtype=v_data.dtype)
2375            else:
2376                out[...] = 0.
2377
2378            for d in range(D):
2379                for p in range(P):
2380                    out[d,p] = numpy.diag(v_data[d,p])
2381
2382            return out
2383
2384        else:
2385            D,P,M,N = v_data.shape
2386            if out is None:
2387                out = numpy.zeros((D,P,N),dtype=v_data.dtype)
2388
2389            for d in range(D):
2390                for p in range(P):
2391                    out[d,p] = numpy.diag(v_data[d,p])
2392
2393            return out
2394
2395    @classmethod
2396    def _diag_pullback(cls, ybar_data, x_data, y_data, k = 0, out = None):
2397        """computes tr(ybar.T, dy) = tr(xbar.T,dx)
2398        where y = diag(x)
2399        """
2400
2401        if out is None:
2402            raise NotImplementedError('should implement that')
2403
2404        if k != 0:
2405            raise NotImplementedError('should implement that')
2406
2407
2408        D,P = x_data.shape[:2]
2409        for d in range(D):
2410            for p in range(P):
2411                out[d,p] += numpy.diag(ybar_data[d,p])
2412
2413        return out
2414