1"""Utilities for enumeration of finite and countably infinite sets.
2"""
3from __future__ import absolute_import, division, print_function
4###
5# Countable iteration
6
# Countably infinite cardinal usable as an integer bound; simplifies size
# calculations that involve infinite sets.
class Aleph0(int):
    """The cardinality of a countably infinite set, usable as an int bound.

    A singleton int subclass: it compares greater than every other number
    and equal only to itself. Adding to it, multiplying it by a non-zero
    value, dividing it, or raising it to a non-zero power yields aleph0
    again, and subtracting it is an error. This lets bounds such as W*H or
    H**k be computed without special-casing infinity.

    NOTE: the underlying int value is 0, so anything that converts it with
    int() or uses it as an index (e.g. range(aleph0)) sees 0.
    """
    _singleton = None

    def __new__(cls):
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self): return '<aleph0>'
    def __str__(self): return 'inf'

    # Python 3 ignores __cmp__; without rich comparisons the int value (0)
    # would be compared, making aleph0 behave like zero. Because Aleph0
    # subclasses int, these are also tried first for reflected operations
    # such as `5 < aleph0`.
    def __eq__(self, b):
        return isinstance(b, Aleph0)

    def __ne__(self, b):
        return not isinstance(b, Aleph0)

    def __lt__(self, b):
        return False

    def __le__(self, b):
        return isinstance(b, Aleph0)

    def __gt__(self, b):
        return not isinstance(b, Aleph0)

    def __ge__(self, b):
        return True

    # Defining __eq__ would otherwise make instances unhashable on Python 3.
    __hash__ = int.__hash__

    # Python 2 three-way comparison fallback, consistent with the above.
    def __cmp__(self, b):
        return 0 if isinstance(b, Aleph0) else 1

    # An infinite cardinal is not "empty"; don't inherit the falsy int 0.
    def __bool__(self):
        return True
    __nonzero__ = __bool__

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")
    __rsub__ = __sub__

    def __add__(self, b):
        return self
    __radd__ = __add__

    def __mul__(self, b):
        if b == 0: return b
        return self
    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0: raise ZeroDivisionError
        return self
    # NOTE(review): the reflected forms return aleph0 as well (so e.g.
    # 5 // aleph0 is aleph0, not 0); kept to match the original aliases.
    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    # Previously misspelled "__rtuediv__", so it never took effect.
    __rtruediv__ = __floordiv__
    __div__ = __floordiv__
    __rdiv__ = __floordiv__

    def __pow__(self, b):
        if b == 0: return 1
        return self
aleph0 = Aleph0()
46
def base(line):
    """Return the index of the first pair on diagonal `line`.

    Diagonal k holds k+1 pairs, so this is the count of pairs on diagonals
    0 .. line-1, i.e. the line-th triangular number.
    """
    count = (line + 1) * line
    return count // 2
49
def pairToN(pair):
    """Return the diagonal-order index of the pair (x, y) of naturals.

    This is the inverse of getNthPair: the pair sits on diagonal x + y at
    position y along it.
    """
    x, y = pair
    return base(x + y) + y
54
def getNthPairInfo(N):
    """Return (line, index) locating N in the diagonal enumeration.

    `line` is the diagonal holding the N-th pair and `index` its position
    along that diagonal, so that base(line) + index == N.
    """
    # Avoid various singularities
    if N == 0:
        return (0, 0)

    # Exponential search: double the upper candidate until base() passes N.
    lower, upper = 1, 2
    while base(upper) <= N:
        lower, upper = upper, upper << 1

    # Bisect, maintaining base(lower) <= N < base(upper).
    while upper - lower != 1:
        middle = (lower + upper) >> 1
        if base(middle) > N:
            upper = middle
        else:
            lower = middle

    return lower, N - base(lower)
80
def getNthPair(N):
    """Return the N-th pair (x, y) of naturals in diagonal order.

    This is the inverse of pairToN.
    """
    diagonal, y = getNthPairInfo(N)
    x = diagonal - y
    return (x, y)
84
def getNthPairBounded(N,W=aleph0,H=aleph0,useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    Bounds may be finite or aleph0. By default, pairs are enumerated along
    anti-diagonals (x + y constant) clipped to the W x H rectangle. With
    useDivmod, the coordinate with the smaller bound varies fastest (x when
    W <= H).

    Raises ValueError if W or H is not positive, or if N is outside
    [0, W*H)."""

    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    elif N < 0 or N >= W*H:
        # Negative N must be rejected explicitly; the diagonal walk below
        # would otherwise return a pair outside the bounds.
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W <= H (solve transposed, swap back).
    if H < W:
        x,y = getNthPairBounded(N,H,W,useDivmod=useDivmod)
        return y,x

    if useDivmod:
        return N%W,N//W

    # Conceptually we slide a diagonal line across the rectangle; this
    # gives more interesting results for large bounds than divmod. The
    # rectangle splits into three regions:
    #   - the lower-left triangle of diagonals 0 .. W-1 (base(W) pairs);
    #   - (finite H only) the pairs nearest the far corner (W-1, H-1),
    #     found by mirroring the lower-left ordering;
    #   - a middle band of diagonals holding exactly W pairs each.

    # If in lower left, just return as usual
    cornerSize = base(W)
    if N < cornerSize:
        return getNthPair(N)

    # Otherwise if in upper right, subtract from corner
    if H is not aleph0:
        M = W*H - N - 1
        if M < cornerSize:
            x,y = getNthPair(M)
            return (W-1-x,H-1-y)

    # Otherwise, compute line and index from the number of times we wrap.
    N = N - cornerSize
    index,offset = N%W,N//W
    # p = (W-1, 1+offset) + (-1,1)*index
    return (W-1-index, 1+offset+index)
def getNthPairBoundedChecked(N,W=aleph0,H=aleph0,useDivmod=False,GNP=getNthPairBounded):
    """Debug wrapper around getNthPairBounded.

    Asserts that the returned pair satisfies 0 <= x < W and 0 <= y < H,
    then returns it as a fresh tuple. GNP is bound to the unchecked
    function when this wrapper is defined.
    """
    x, y = GNP(N, W, H, useDivmod)
    assert 0 <= x < W
    assert 0 <= y < H
    return x, y
133
def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple, where 0 <= x_i < H for every element.

    By default, the tuple is split recursively into a left and right half,
    each indexed by one coordinate of a bounded pair. With useLeftToRight,
    elements are instead peeled off one at a time from the left."""

    if useLeftToRight:
        if W <= 0:
            return ()
        elts = [None]*W
        # Each step splits N into the next element (< H) and an index into
        # the H**k possible suffixes of the k elements still to decode; the
        # last element takes the index that remains. (Splitting with an
        # unbounded second coordinate and discarding the leftover index
        # would map many different N to the same tuple.)
        for i in range(W - 1):
            elts[i], N = getNthPairBounded(N, H, H**(W - 1 - i))
        elts[W - 1] = N
        return tuple(elts)

    if W==0:
        return ()
    elif W==1:
        return (N,)
    elif W==2:
        return getNthPairBounded(N, H, H)
    # Split into LW left and RW right elements; a bounded pair indexes
    # into the (H**LW) x (H**RW) combinations of the two halves.
    LW,RW = W//2, W - (W//2)
    L,R = getNthPairBounded(N, H**LW, H**RW)
    return (getNthNTuple(L,LW,H=H,useLeftToRight=useLeftToRight) +
            getNthNTuple(R,RW,H=H,useLeftToRight=useLeftToRight))
def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Debug wrapper around getNthNTuple.

    Asserts that the result has exactly W elements, each below H, and
    returns it unchanged. GNT is bound to the unchecked function when this
    wrapper is defined.
    """
    result = GNT(N, W, H, useLeftToRight)
    assert len(result) == W
    assert all(elt < H for elt in result)
    return result
162
def getNthTuple(N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize and for y in x, 0 <=
    y < maxElement.

    Index 0 is the empty tuple. Every other N is split into a pair (S, M)
    and decoded as the M-th tuple of length S+1. When maxElement is
    finite, the pair is drawn with per-length bounds maxElement**(S+1).
    useDivmod only affects the case where maxElement is aleph0;
    useLeftToRight is forwarded to getNthNTuple.

    Raises NotImplementedError if maxElement is finite but maxSize is
    aleph0."""

    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()
    N -= 1
    if maxElement is not aleph0:
        if maxSize is aleph0:
            raise NotImplementedError('Max element size without max size unhandled')
        # bounds[S] is the number of tuples of length S+1.
        bounds = [maxElement**i for i in range(1, maxSize+1)]
        S,M = getNthPairVariableBounds(N, bounds)
    else:
        S,M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
    return getNthNTuple(M, S+1, maxElement, useLeftToRight=useLeftToRight)
def getNthTupleChecked(N, maxSize=aleph0, maxElement=aleph0,
                       useDivmod=False, useLeftToRight=False, GNT=getNthTuple):
    """Debug wrapper around getNthTuple.

    Asserts that the tuple's length is at most maxSize (the bound is
    inclusive) and that every element is below maxElement, then returns the
    tuple unchanged. GNT is bound to the unchecked function when this
    wrapper is defined.
    """
    result = GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)
    assert len(result) <= maxSize
    assert all(elt < maxElement for elt in result)
    return result
189
def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    Raises ValueError if bounds is empty or contains a negative bound, or
    if N is not in [0, sum(bounds))."""

    if not bounds or any(b < 0 for b in bounds):
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Visit columns in order of increasing bound. At each level, the
    # columns whose bound reaches that level form a W x H slab: W remaining
    # columns by H new rows above the previous level. Find the slab that
    # contains N and enumerate within it.
    active = sorted(range(len(bounds)), key=lambda i: bounds[i])
    prevLevel = 0
    for i,index in enumerate(active):
        level = bounds[index]
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W*H
        if N<levelSize: # Found the level
            idelta,delta = getNthPairBounded(N, W, H)
            return active[i+idelta],prevLevel+delta
        N -= levelSize
        prevLevel = level

    # The range check above guarantees that some level contains N.
    raise RuntimeError("Unexpected loop completion")
222
def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Debug wrapper around getNthPairVariableBounds.

    Asserts that the pair names a valid column and lies below that column's
    bound, then returns it. GNVP is bound to the unchecked function when
    this wrapper is defined.
    """
    column, row = GNVP(N, bounds)
    assert 0 <= column < len(bounds)
    assert 0 <= row < bounds[column]
    return (column, row)
227
228###
229
def testPairs():
    """Demo: list the pairs of a 3 x 6 rectangle in diagonal and divmod order.

    Each ordering is then drawn as a grid, with row 0 at the bottom.
    """
    W = 3
    H = 6
    diagonal = [['  '] * 10 for _ in range(10)]
    divmodGrid = [['  '] * 10 for _ in range(10)]
    for i in range(min(W*H, 40)):
        x, y = getNthPairBounded(i, W, H)
        dx, dy = getNthPairBounded(i, W, H, useDivmod=True)
        print(i, (x, y), (dx, dy))
        label = '%2d' % i
        diagonal[y][x] = label
        divmodGrid[dy][dx] = label

    def dump(header, grid):
        # Print the top row first and skip rows that hold no labels.
        print(header)
        for row in reversed(grid):
            if ''.join(row).strip():
                print('  '.join(row))

    dump('-- a --', diagonal)
    dump('-- b --', divmodGrid)
250
def testPairsVB():
    """Demo: list pairs from getNthPairVariableBounds for mixed bounds.

    The bounds mix finite values and aleph0. The pairs are then drawn as a
    grid, with row 0 at the bottom.
    """
    bounds = [2,2,4,aleph0,5,aleph0]
    # Only one grid is needed; the second grid the original built was
    # never written to or printed.
    a = [['  ' for _ in range(15)] for _ in range(15)]
    for i in range(min(sum(bounds),40)):
        x,y = getNthPairVariableBounds(i, bounds)
        print(i,(x,y))
        a[y][x] = '%2d'%i

    print('-- a --')
    for ln in a[::-1]:
        if ''.join(ln).strip():
            print('  '.join(ln))
264
265###
266
# Toggle to use checked versions of enumeration routines.
# Change the condition to True to rebind the public names below to their
# *Checked wrappers, which assert that every result lies within the
# requested bounds. Each wrapper captured the unchecked function as a
# default argument (GNP=, GNT=, GNVP=) when it was defined above, so this
# rebinding does not make a wrapper call itself. Functions that look these
# names up globally at call time (e.g. getNthTuple calling
# getNthPairBounded) will then go through the checked versions. This must
# stay after all the definitions it rebinds.
if False:
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthNTuple = getNthNTupleChecked
    getNthTuple = getNthTupleChecked
273
if __name__ == '__main__':
    # Run the demos in order: bounded pairs, then variable-bounds pairs.
    for demo in (testPairs, testPairsVB):
        demo()
278
279