1from __future__ import print_function
2
3# Python 3 does not have long, only int
4try:
5    long
6except NameError:
7    long = int
8
9import numpy as np
10from math import atan2
11from cylp.py import Constants
12from operator import mul
13
def sign(x):
    '''Return +1 or -1 for *x*, honouring the sign bit of a floating-point zero.

    Positive values and +0.0 (including the integer 0) give 1; negative
    values, -0.0 and NaN give -1.  ``atan2(+0.0, -1.0)`` is +pi while
    ``atan2(-0.0, -1.0)`` is -pi, which is how the two zeros are told apart.
    '''
    if x > 0:
        return 1
    # x is zero here only if it compares equal to 0; the atan2 probe then
    # distinguishes +0.0 from -0.0.
    if x == 0 and atan2(x, -1.) > 0.:
        return 1
    return -1
19
def get_cs(w1, w2):
    '''Return the cosine/sine pair ``(c, s)`` used to build a Givens-type rotation.

    ``c = w1 / omega`` and ``s = w2 / omega``, where ``omega`` is the Euclidean
    norm of ``(w1, w2)`` carrying the sign chosen by ``sign(w1)``.  Thus,
    mathematically, ``c * w1 + s * w2 == omega`` and ``s * w1 - c * w2 == 0``.

    Both inputs being zero gives ``omega == 0`` and hence a division by zero;
    ``givens`` screens out that case (both values below ``Constants.EPSILON``)
    before calling this function.
    '''
    from math import hypot

    # hypot computes sqrt(w1**2 + w2**2) without the intermediate overflow
    # (to inf) or underflow (to 0) of squaring very large/small magnitudes.
    omega = float(sign(w1) * hypot(w1, w2))
    c = w1 / omega
    s = w2 / omega
    return c, s
25
def givens(n, i, j, w1, w2):
    '''Return an ``n x n`` float ndarray that rotates rows *i* and *j*.

    The result is the identity except on the (i, j) rows/columns, where the
    2x2 block becomes ``[[c, s], [s, -c]]`` with ``(c, s) = get_cs(w1, w2)``.
    Left-multiplying a vector whose entries i and j are w1 and w2 mathematically
    zeroes entry j.  When both |w1| and |w2| are below ``Constants.EPSILON``
    the plain identity is returned.
    '''
    rot = np.identity(n, float)
    if not (abs(w1) < Constants.EPSILON and abs(w2) < Constants.EPSILON):
        c, s = get_cs(w1, w2)
        rot[i, i] = c
        rot[j, j] = -c
        rot[i, j] = s
        rot[j, i] = s
    return rot
33
def applyGivens(vec):
    '''Apply up to dim-1 givens matrices so that vec contains only one non-zero element.

    Rotations sweep bottom-up over adjacent row pairs ``(k-1, k)``.  Each one
    folds entry ``k`` into entry ``k-1`` and zeroes entry ``k``, so the
    surviving value ends up in row 0.  ``givens`` returns the identity when
    both entries are below its tolerance.

    :param vec: column vector indexed as ``vec[row, 0]``.  Assumed to be an
        ``np.matrix`` so that ``*`` below is a matrix product -- TODO confirm
        with callers.
    :returns: ``(Qlist, Q_bar, v0)``:

        - ``Qlist`` holds the rotation matrices in the order they were applied.
        - ``Q_bar`` is their accumulated product, with the latest rotation
          leftmost.
        - ``v0`` is the remaining value ``v[0, 0]``.
    '''
    v = vec.copy()
    dim = v.shape[0]
    Q_bar = np.matrix(np.identity(dim, float))
    Qlist = []
    for i in range(dim - 1):
        lower = dim - i - 1
        # Test the *current* vector, not the input: the previous rotation may
        # have turned an originally-zero entry non-zero, and it must still be
        # eliminated for only one non-zero entry to remain.
        if v[lower, 0] != 0:
            Q = givens(dim, lower - 1, lower, v[lower - 1, 0], v[lower, 0])
            Qlist.append(Q)
            v = Q * v
            Q_bar = Q * Q_bar
    return Qlist, Q_bar, v[0, 0]
47
def UH2UT(mat):
    '''Reduce *mat* to upper-triangular form with one Givens rotation per column.

    For each column ``col`` a rotation of rows ``(col, col + 1)`` eliminates
    the entry just below the diagonal.  The function name and this single
    sub-diagonal elimination suggest an upper-Hessenberg input with one more
    row than columns -- TODO confirm with callers -- since row ``col + 1`` is
    read for every column.  Assumes ``*`` is a matrix product, i.e. *mat* is
    an ``np.matrix``.

    :returns: ``(Q_bar, R)``:

        - ``Q_bar`` is the accumulated rotation, with the latest rotation
          leftmost.
        - ``R`` is the rotated matrix with its last row removed.
    '''
    tri = mat.copy()
    nrows, ncols = mat.shape
    accumulated = np.matrix(np.identity(nrows, float))
    for col in range(ncols):
        rot = givens(nrows, col, col + 1, tri[col, col], tri[col + 1, col])
        tri = rot * tri
        accumulated = rot * accumulated
    # The last row is all zero after the sweep; dropping it yields the square
    # matrix the caller needs.
    return accumulated, tri[:nrows - 1, :]
58
59
## This part defines the decorators 'precondition', 'postcondition' and 'conditions'
61
__all__ = ['precondition', 'postcondition', 'conditions']
# NOTE(review): __all__ names only the decorators, so a star-import of this
# module does not export the other helpers defined in the file; confirm that
# this is intended.

# Default for whether the decorators below actually install their checks.
DEFAULT_ON = True


def precondition(precondition, use_conditions=DEFAULT_ON):
    '''Decorator factory: call ``precondition(*args, **kwargs)`` before each call.

    Shorthand for ``conditions(precondition, None, use_conditions)``.
    '''
    return conditions(precondition, None, use_conditions)


def postcondition(postcondition, use_conditions=DEFAULT_ON):
    '''Decorator factory: call ``postcondition(result, *args, **kwargs)`` after each call.

    Shorthand for ``conditions(None, postcondition, use_conditions)``.
    '''
    return conditions(None, postcondition, use_conditions)


class conditions(object):
    '''Decorator attaching a precondition and/or postcondition to a function.

    Stacked decorators from this family are merged.  Already-applied
    ``FunctionWrapper`` layers are unwrapped and their distinct conditions
    re-paired, so ``@precondition(p)`` stacked with ``@postcondition(q)``
    behaves like ``@conditions(p, q)``.

    :param pre: callable invoked as ``pre(*args, **kwargs)`` before the call,
        or None.
    :param post: callable invoked as ``post(result, *args, **kwargs)`` after
        the call, or None.
    :param use_conditions: when false, this decorator contributes no
        conditions of its own.
    '''
    __slots__ = ('__precondition', '__postcondition')

    def __init__(self, pre, post, use_conditions=DEFAULT_ON):
        if not use_conditions:
            pre, post = None, None

        self.__precondition = pre
        self.__postcondition = post

    def __call__(self, function):
        # zip_longest has different names on Python 3 and Python 2.
        try:
            from itertools import zip_longest
        except ImportError:  # Python 2
            from itertools import izip_longest as zip_longest

        # Combine recursive wrappers: collect distinct pre-/postconditions in
        # first-seen order, this decorator's own conditions first, then those
        # of every wrapper already applied to the function.
        pres = [self.__precondition]
        posts = [self.__postcondition]
        while type(function) is FunctionWrapper:
            if function._pre not in pres:
                pres.append(function._pre)
            if function._post not in posts:
                posts.append(function._post)
            function = function._func

        # Drop falsy (None) conditions and pair pres with posts, padding the
        # shorter list with None.  The Python-2-only ``map(None, a, b)`` idiom
        # raises TypeError on Python 3.
        pairs = zip_longest([p for p in pres if p], [q for q in posts if q])

        # One wrapper per pair; with no conditions at all, the unwrapped
        # function is returned unchanged.
        for pre, post in pairs:
            function = FunctionWrapper(pre, post, function)

        return function


class FunctionWrapper(object):
    '''Callable running optional pre-/postcondition checks around a function.

    ``conditions`` creates these and reads ``_pre``, ``_post`` and ``_func``
    back when decorators are stacked.
    '''
    # NOTE(review): no __get__ is defined, so an instance placed in a class
    # body is not bound as a method (``self`` is not passed on to ``_func``).

    def __init__(self, precondition, postcondition, function):
        self._pre = precondition
        self._post = postcondition
        self._func = function

    def __call__(self, *args, **kwargs):
        pre = self._pre
        post = self._post

        if pre:
            pre(*args, **kwargs)
        result = self._func(*args, **kwargs)
        if post:
            # The postcondition sees the result first, then the call's arguments.
            post(result, *args, **kwargs)
        return result
118
119
class Ind:
    '''Positions selected along one axis of length ``dim``.

    Instances expose ``indices`` and ``dim`` (the axis length).  ``indices``
    holds the selected positions:

    - a ``range`` for slice keys;
    - a one-element list for integer keys;
    - the key itself for list/array keys.
    '''
    def __init__(self, key, dim):
        '''
        Create an instance of Ind using *key* that can be
        an integer, a slice, a list, or a numpy array.

        Integer keys may be any ``numbers.Integral`` (this includes numpy
        integer scalars).  Raises ``Exception`` when *key* is out of range for
        *dim* or has an unsupported type.
        '''
        import numbers

        if isinstance(key, slice):
            sl = key
            # Only validate when both ends are given: comparing a None start
            # with an int raises TypeError on Python 3 (e.g. slice(None, 3)).
            if sl.stop and sl.start is not None and (
                    sl.start > dim or sl.start >= sl.stop):
                raise Exception('Indexing problem: %s, dim=%d:' % (str(sl), dim))

            # A missing (or zero) stop, or one beyond dim, is clamped to dim.
            if not sl.stop or sl.stop > dim:
                stop = dim
            else:
                stop = sl.stop
            start = sl.start if sl.start else 0
            self.indices = range(start, stop)
            self.dim = dim
        elif isinstance(key, numbers.Integral):
            if key >= dim:
                raise Exception('Index (%d) out of range (%d)' % (key, dim))
            self.indices = [key]
            self.dim = dim
        elif isinstance(key, (list, np.ndarray)):
            self.indices = key
            self.dim = dim
            # To avoid checking every index, assume the list is normally
            # sorted and only check the last one.  An empty key selects
            # nothing and is accepted.
            if len(key) > 0 and key[-1] >= dim:
                raise Exception('Index (%d) out of range (%d)' % (key[-1], dim))
        else:
            raise Exception('Error indexing with unrecognized type %s' % key.__class__)

    def __repr__(self):
        # Only ``indices`` and ``dim`` are stored on the instance, so the
        # representation is built from those two attributes.
        return 'Ind([%s], dim=%d)' % (
            ', '.join(str(i) for i in self.indices), self.dim)
159
def getIndS(inds):
    '''Fold a sequence of index objects into a single flat offset.

    For ``inds = [a_0, ..., a_{n-1}]`` the result is the sum of
    ``a_k.stop * a_{k+1}.dim * ... * a_{n-1}.dim`` for every ``k < n-1``,
    plus ``a_{n-1}.start`` for the last object.

    NOTE(review): ``Ind`` objects as constructed in this module only set
    ``indices`` and ``dim``, not ``start``/``stop``; confirm what callers pass.
    '''
    count = len(inds)
    if count == 1:
        return inds[0].start
    # Leading terms: each stop scaled by the product of all later dims.
    weights = []
    for pos in range(count - 1):
        weight = inds[pos].stop
        for later in inds[pos + 1:]:
            weight *= later.dim
        weights.append(weight)
    # Accumulate from the innermost axis outwards, matching the nesting
    # a_0-term + (a_1-term + (... + a_{n-1}.start)).
    offset = inds[count - 1].start
    for weight in reversed(weights):
        offset = weight + offset
    return offset
168
def getMultiDimMatrixIndex(inds, res=None):
    '''Return the flat row-major positions of the cartesian product of *inds*.

    Each element of *inds* must expose ``indices`` (positions chosen along that
    axis) and ``dim`` (the axis length), as ``Ind`` does.  The position of
    ``(i_0, ..., i_{n-1})`` is ``i_0 * dim_1 * ... * dim_{n-1} + ... + i_{n-1}``,
    and positions are listed with the last axis varying fastest.

    :param res: optional list whose elements are inserted before the block of
        positions generated for each index of the first axis.  It is also
        passed down the recursion.  ``None`` (the default) means no extra
        elements.
    :returns: ``inds[0].indices`` itself when there is a single axis,
        otherwise a new list.
    '''
    if res is None:
        res = []
    n = len(inds)
    first = inds[0].indices
    if n == 1:
        return first
    flat = []
    if len(first) == 0:
        return flat
    # Row-major stride of the first axis: product of all trailing dims.
    stride = 1
    for ind in inds[1:]:
        stride *= ind.dim
    # The positions of the trailing axes do not depend on the first-axis
    # index, so compute them once instead of once per index.
    rest = getMultiDimMatrixIndex(inds[1:], res)
    for i in first:
        offset = i * stride
        flat += res + [offset + rs for rs in rest]
    return flat
183
def getTupleIndex(ind, dims):
    '''Convert flat row-major index *ind* into per-axis coordinates.

    :param ind: flat index.
    :param dims: axis lengths as a sequence, or a single integer for a 1-D extent.
    :returns: list of coordinates (last axis varying fastest), or -1 when
        *ind* is not below the total number of elements.
    '''
    import numbers

    if isinstance(dims, numbers.Integral):
        return [ind] if ind < dims else -1
    total = 1
    for d in dims:
        total *= d
    # Valid flat indices are 0 .. total-1, so ind == total is out of range too.
    if ind >= total:
        return -1
    if len(dims) == 1:
        return [ind]
    coords = []
    # Peel coordinates off from the last (fastest-varying) axis; floor
    # division keeps them integral on Python 3.
    for d in reversed(dims):
        ind, rem = divmod(ind, d)
        coords.append(rem)
    coords.reverse()
    return coords
200
if __name__ == '__main__':
    # Demo: enumerate the flat positions selected by three per-axis index sets
    # of a 5 x 6 x 7 array and print each next to its coordinate tuple.
    shape = (5, 6, 7)
    axes = [
        Ind(slice(1, 4), 5),
        Ind(slice(2, 4), 6),
        Ind(np.array([1, 4, 6]), 7),
    ]
    for flat in getMultiDimMatrixIndex(axes):
        print(getTupleIndex(flat, shape), flat)

    # Flat index 8 is outside a 1-D extent of 8, so this prints -1.
    print(getTupleIndex(8, 8))
214