1"""Utilities for enumeration of finite and countably infinite sets.
2"""
3###
4# Countable iteration
5
# Integer-like stand-in for countable infinity (aleph-null); simplifies
# arithmetic and comparisons on possibly-unbounded sizes.
class Aleph0(int):
    """Singleton stand-in for aleph-null, the size of a countably infinite set.

    The single instance ``aleph0`` is used throughout this module as an
    "unbounded" value and is recognised by identity (``x is aleph0``).

    Semantics provided:
      * ordering: aleph0 compares greater than every value, itself included
        (so even ``aleph0 == aleph0`` is False);
      * ``aleph0 + b`` and ``b + aleph0`` are aleph0;
      * ``aleph0 * b`` (either side) is ``b`` when ``b == 0``, else aleph0;
      * division in either direction is aleph0, raising ZeroDivisionError
        when the other operand is 0;
      * ``aleph0 ** 0`` is 1, any other exponent gives aleph0;
      * subtraction in either direction raises ValueError.

    NOTE: ``int.__new__`` is called without a value, so the underlying
    integer is 0: ``int(aleph0) == 0``, ``hash(aleph0) == hash(0)`` and
    ``bool(aleph0)`` is False.
    """

    _singleton = None

    def __new__(cls):
        # Create the instance once and hand back the same object afterwards.
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self):
        return '<aleph0>'

    def __str__(self):
        return 'inf'

    # Python 2 three-way comparison: always "greater".
    def __cmp__(self, b):
        return 1

    # Python 3 ignores __cmp__ entirely (aleph0 would compare as the int 0),
    # so the same ordering is spelled out with rich comparisons.  Python 2
    # tries these before __cmp__, and they agree with it.
    def __lt__(self, b):
        return False

    def __le__(self, b):
        return False

    def __gt__(self, b):
        return True

    def __ge__(self, b):
        return True

    def __eq__(self, b):
        return False

    def __ne__(self, b):
        return True

    # Defining __eq__ would otherwise make instances unhashable on Python 3.
    __hash__ = int.__hash__

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")
    __rsub__ = __sub__

    def __add__(self, b):
        return self
    __radd__ = __add__

    def __mul__(self, b):
        if b == 0:
            return b
        return self
    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0:
            raise ZeroDivisionError
        return self
    # NOTE(review): the reflected forms (b // aleph0, b / aleph0) also
    # return aleph0, mirroring the forward operator, where b / inf would
    # mathematically be 0. Confirm callers rely on this.
    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    __rtruediv__ = __floordiv__
    __div__ = __floordiv__      # Python 2 classic division
    __rdiv__ = __floordiv__

    def __pow__(self, b):
        if b == 0:
            return 1
        return self


# The unique aleph0 instance; compare against it with ``is``.
aleph0 = Aleph0()
45
def base(line):
    """Return the triangular number ``line * (line + 1) / 2``.

    In the diagonal enumeration this is the number of pairs on lines
    0 .. line-1, i.e. the position at which line ``line`` starts.
    """
    product = line * (line + 1)
    return product // 2
48
def pairToN(pair):
    """Return the position N at which ``pair`` appears in the diagonal order.

    This is the inverse of getNthPair: ``pair`` is an ``(x, y)`` sequence
    lying on line ``x + y``, which starts at ``base(x + y)``, and ``y`` is
    its offset within that line.
    """
    # Unpack explicitly: tuple parameters in the signature are Python 2-only.
    x, y = pair
    line, index = x + y, y
    return base(line) + index
52
def getNthPairInfo(N):
    """Locate position N of the diagonal enumeration.

    Return ``(line, index)`` where ``line`` is the largest value with
    ``base(line) <= N`` and ``index = N - base(line)`` is N's offset within
    that line.  N is assumed to be a non-negative integer.
    """
    # N == 0 is answered directly so the search can start at line 1,
    # for which base(1) == 1 <= N holds.
    if N == 0:
        return (0, 0)

    # Galloping phase: double `upper` until base(upper) exceeds N; `lower`
    # trails one step behind, so base(lower) <= N on exit.
    lower, upper = 1, 2
    while base(upper) <= N:
        lower, upper = upper, upper << 1

    # Bisection phase, maintaining base(lower) <= N < base(upper).
    while upper - lower > 1:
        middle = (lower + upper) >> 1
        if base(middle) <= N:
            lower = middle
        else:
            upper = middle

    return lower, N - base(lower)
78
def getNthPair(N):
    """Return the N-th pair (x, y) of the unbounded diagonal enumeration.

    Pairs are ordered by line ``x + y``; within a line, y counts up from 0
    while x counts down.
    """
    diagonal, position = getNthPairInfo(N)
    x = diagonal - position
    y = position
    return (x, y)
82
def getNthPairBounded(N, W=aleph0, H=aleph0, useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    Either bound may be ``aleph0``.  By default pairs are enumerated along
    diagonals (constant x + y) clipped to the W x H rectangle.  With
    ``useDivmod`` the enumeration walks rows of the narrower dimension
    instead (modulus for the narrow coordinate, quotient for the other).

    Raises ValueError if either bound is not positive, or if N is outside
    ``0 <= N < W*H``.
    """
    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    if N < 0 or N >= W * H:
        raise ValueError("Invalid input (out of bounds)")

    # Unbounded in both directions: plain diagonal enumeration.
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Reduce to the case W <= H by transposing.
    if H < W:
        x, y = getNthPairBounded(N, H, W, useDivmod=useDivmod)
        return y, x

    if useDivmod:
        return N % W, N // W

    # Slide a diagonal line across the rectangle; for large bounds this
    # mixes both coordinates better than divmod does.

    # Lower-left triangle: lines 0 .. W-1 fit entirely inside the rectangle.
    cornerSize = base(W)
    if N < cornerSize:
        return getNthPair(N)

    # Upper-right corner: the point reflection of the lower-left triangle,
    # counted backwards from the last cell (W-1, H-1).
    if H is not aleph0:
        M = W * H - N - 1
        if M < cornerSize:
            x, y = getNthPair(M)
            return (W - 1 - x, H - 1 - y)

    # Middle band: each line here holds exactly W cells, so the line and
    # the position within it follow from how many times we wrap around W.
    N = N - cornerSize
    index, offset = N % W, N // W
    # p = (W-1, 1+offset) + (-1,1)*index
    return (W - 1 - index, 1 + offset + index)
def getNthPairBoundedChecked(N, W=aleph0, H=aleph0, useDivmod=False, GNP=getNthPairBounded):
    """Call ``GNP(N, W, H, useDivmod)`` and assert the result lies in
    ``[0, W) x [0, H)`` before returning it.

    GNP defaults to the getNthPairBounded bound when this function is
    defined, so rebinding the module-level name later does not recurse.
    """
    pair = GNP(N, W, H, useDivmod)
    x, y = pair
    assert 0 <= x < W
    assert 0 <= y < H
    return x, y
131
def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple, where 0 <= x_i < H for every i.

    By default the tuple is split into a left half of W//2 elements and a
    right half holding the rest.  N is paired (getNthPairBounded with bounds
    H**(W//2) and H**(W - W//2)) into one index per half, and each half is
    built recursively.  With ``useLeftToRight`` elements are instead peeled
    off one at a time from the front.

    N is expected to be below H**W when H is finite; this is not checked
    here.

    Raises ValueError if W is negative, which would otherwise recurse without
    end in the default mode.
    """
    if W < 0:
        raise ValueError("Invalid tuple width")

    if useLeftToRight:
        # Pair N into (element, remainder) and continue with the remainder.
        # NOTE(review): the remainder left after the last element is
        # discarded, so distinct N can give the same tuple (e.g. W=1,
        # H=aleph0: N=0 and N=2 both yield (0,)). Confirm this is intended.
        elts = []
        for _ in range(W):
            elt, N = getNthPairBounded(N, H)
            elts.append(elt)
        return tuple(elts)

    if W == 0:
        return ()
    if W == 1:
        return (N,)
    if W == 2:
        return getNthPairBounded(N, H, H)

    LW = W // 2
    RW = W - LW
    L, R = getNthPairBounded(N, H**LW, H**RW)
    return (getNthNTuple(L, LW, H=H, useLeftToRight=useLeftToRight) +
            getNthNTuple(R, RW, H=H, useLeftToRight=useLeftToRight))
def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Call ``GNT(N, W, H, useLeftToRight)`` and assert the result has
    exactly W elements, each below H, before returning it.

    GNT defaults to the getNthNTuple bound when this function is defined.
    """
    result = GNT(N, W, H, useLeftToRight)
    assert len(result) == W
    assert all(element < H for element in result)
    return result
160
def getNthTuple(N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize and for y in x,
    0 <= y < maxElement.

    Index 0 is the empty tuple.  For N >= 1, N-1 is split into a
    (size, index) pair and ``index`` selects a tuple of length ``size + 1``
    via getNthNTuple:
      * when maxElement is finite, getNthPairVariableBounds is used so that
        each length s gets exactly maxElement**s tuples;
      * otherwise getNthPairBounded is used (``useDivmod`` applies only here).

    Raises NotImplementedError when maxElement is finite but maxSize is
    aleph0.
    """
    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()
    N -= 1

    if maxElement is not aleph0:
        if maxSize is aleph0:
            raise NotImplementedError('Max element size without max size unhandled')
        # bounds[s] is the number of tuples of length s + 1.
        bounds = [maxElement**i for i in range(1, maxSize + 1)]
        S, M = getNthPairVariableBounds(N, bounds)
    else:
        S, M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)

    return getNthNTuple(M, S + 1, maxElement, useLeftToRight=useLeftToRight)
def getNthTupleChecked(N, maxSize=aleph0, maxElement=aleph0,
                       useDivmod=False, useLeftToRight=False, GNT=getNthTuple):
    """Call ``GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)`` and
    assert the tuple's length and elements respect the bounds.

    GNT defaults to the getNthTuple bound when this function is defined.
    """
    result = GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)
    # maxSize bounds the tuple length inclusively (len <= maxSize).
    assert len(result) <= maxSize
    assert all(element < maxElement for element in result)
    return result
187
def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and
    0 <= y < bounds[x].

    Columns are visited in order of increasing bound.  The space is cut
    into horizontal bands: between consecutive distinct bound values, every
    still-active column shares the same height.  Each band is therefore a
    rectangle and is enumerated with getNthPairBounded.

    Raises ValueError if ``bounds`` is empty or N is outside
    ``0 <= N < sum(bounds)``.
    """
    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Column indices ordered by bound; sorted() is stable, so ties keep
    # their original order.
    active = sorted(range(len(bounds)), key=lambda i: bounds[i])
    prevLevel = 0
    for i, index in enumerate(active):
        level = bounds[index]
        # Columns active[i:] all reach at least `level`, so the band
        # prevLevel <= y < level is a W x H rectangle over them.
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W * H
        if N < levelSize:  # Found the level
            idelta, delta = getNthPairBounded(N, W, H)
            return active[i + idelta], prevLevel + delta
        N -= levelSize
        prevLevel = level

    # Unreachable after the range check above.
    raise RuntimeError("Unexpected loop completion")
220
def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Call ``GNVP(N, bounds)`` and assert the pair is a valid
    (column, row) cell for ``bounds`` before returning it.

    GNVP defaults to the getNthPairVariableBounds bound when this function
    is defined.
    """
    column, row = GNVP(N, bounds)
    # Check the column first so bounds[column] is only read when valid.
    assert 0 <= column < len(bounds)
    assert 0 <= row < bounds[column]
    return (column, row)
225
226###
227
def testPairs():
    """Print the enumeration of a 3 x 6 rectangle in both orders.

    Each index is listed with its diagonal-order pair and its divmod-order
    pair.  Both orderings are then drawn as grids, highest row first
    ("-- a --" is diagonal order, "-- b --" is divmod order).
    """
    W = 3
    H = 6
    a = [['  ' for _ in range(10)] for _ in range(10)]
    b = [['  ' for _ in range(10)] for _ in range(10)]
    for i in range(min(W * H, 40)):
        x, y = getNthPairBounded(i, W, H)
        x2, y2 = getNthPairBounded(i, W, H, useDivmod=True)
        # Single preformatted string: prints identically on Python 2 and 3.
        print('%d %s %s' % (i, (x, y), (x2, y2)))
        a[y][x] = '%2d' % i
        b[y2][x2] = '%2d' % i

    def show(title, grid):
        # Print non-empty rows, top row first.
        print(title)
        for row in grid[::-1]:
            if ''.join(row).strip():
                print('  '.join(row))

    show('-- a --', a)
    show('-- b --', b)
248
def testPairsVB():
    """Print up to the first 40 pairs of getNthPairVariableBounds for a mixed
    finite/aleph0 bound list, then draw them as a grid, highest row first.
    """
    bounds = [2, 2, 4, aleph0, 5, aleph0]
    a = [['  ' for _ in range(15)] for _ in range(15)]
    for i in range(min(sum(bounds), 40)):
        x, y = getNthPairVariableBounds(i, bounds)
        # Single preformatted string: prints identically on Python 2 and 3.
        print('%d %s' % (i, (x, y)))
        a[y][x] = '%2d' % i

    print('-- a --')
    for ln in a[::-1]:
        if ''.join(ln).strip():
            print('  '.join(ln))
262
263###
264
# Debugging switch: change the condition below to True to replace the public
# enumeration routines with their assertion-checking wrappers.  Each *Checked
# wrapper captured the original routine as a default argument when it was
# defined, so the swap does not make the wrappers call themselves.
if False:
    getNthTuple = getNthTupleChecked
    getNthNTuple = getNthNTupleChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
271
if __name__ == '__main__':
    # Run both demonstrations in order: rectangular bounds, then variable bounds.
    for demo in (testPairs, testPairsVB):
        demo()
276
277