"""Utilities for enumeration of finite and countably infinite sets.
"""
from __future__ import absolute_import, division, print_function
###
# Countable iteration

# Simplifies some calculations
class Aleph0(int):
    """Singleton stand-in for the cardinality of a countably infinite set.

    The enumeration routines below use ``aleph0`` as an "unbounded" size and
    combine it with ordinary ints.  The underlying int value is 0
    (``int.__new__`` is called without a value), so every operation the
    enumeration code relies on is overridden to behave like an infinite
    count.  Operations that are not overridden (e.g. ``%``, ``int()``,
    ``range()``) still see the value 0.
    """
    _singleton = None

    def __new__(cls):
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self): return '<aleph0>'
    def __str__(self): return 'inf'

    # Rich comparisons: aleph0 equals only itself and is greater than
    # everything else.  Python 3 ignores __cmp__, so without these the
    # inherited int comparisons would treat aleph0 as 0 (``aleph0 <= 0``
    # would be true).  Because Aleph0 subclasses int, its reflected methods
    # also take priority for ``int <op> aleph0``.
    def __eq__(self, b): return b is self
    def __ne__(self, b): return b is not self
    def __lt__(self, b): return False
    def __le__(self, b): return b is self
    def __gt__(self, b): return b is not self
    def __ge__(self, b): return True

    # Python 2 fallback, kept consistent with the rich comparisons above.
    def __cmp__(self, b):
        return 0 if b is self else 1

    # Defining __eq__ removes the inherited hash on Python 3; restore it.
    __hash__ = int.__hash__

    # An infinite count is never "empty" (the int value 0 would be falsy).
    def __bool__(self): return True
    __nonzero__ = __bool__

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")
    __rsub__ = __sub__

    def __add__(self, b):
        return self
    __radd__ = __add__

    def __mul__(self, b):
        # 0 * aleph0 is taken to be 0; any other product is aleph0.
        if b == 0: return b
        return self
    __rmul__ = __mul__

    def __floordiv__(self, b):
        # aleph0 divided by a non-zero quantity is still aleph0.
        if b == 0: raise ZeroDivisionError
        return self
    __truediv__ = __floordiv__
    __div__ = __floordiv__

    def __rfloordiv__(self, b):
        # b // aleph0 for a finite b, matching b // float('inf').
        return 0 if b >= 0 else -1
    __rdiv__ = __rfloordiv__

    def __rtruediv__(self, b):
        # b / aleph0 for a finite b.
        return 0

    def __pow__(self, b):
        if b == 0: return 1
        return self
aleph0 = Aleph0()

def base(line):
    """Return line*(line+1)//2, the number of pairs on diagonals 0..line-1.

    This is the diagonal index of the first pair (x, y) with x + y == line."""
    return line*(line+1)//2

def pairToN(pair):
    """pairToN((x, y)) -> N

    Inverse of getNthPair: return the diagonal index of the pair (x, y)."""
    x,y = pair
    line,index = x+y,y
    return base(line)+index

def getNthPairInfo(N):
    """getNthPairInfo(N) -> (line, index)

    Return the largest ``line`` with base(line) <= N (the diagonal x + y
    holding the N-th pair) and ``index`` = N - base(line), the position of
    the pair along that diagonal."""
    # Avoid various singularities
    if N==0:
        return (0,0)

    # Gallop to find bounds for line: afterwards base(line) <= N and
    # base(line << 1) > N.
    line = 1
    upper = 2
    while base(upper)<=N:
        line = upper
        upper = line << 1

    # Binary search for starting line; invariant: base(lo) <= N < base(hi).
    lo = line
    hi = line<<1
    while lo + 1 != hi:
        mid = (lo + hi)>>1
        if base(mid)<=N:
            lo = mid
        else:
            hi = mid

    line = lo
    return line, N - base(line)

def getNthPair(N):
    """getNthPair(N) -> (x, y)

    Return the N-th pair of naturals in diagonal order:
    (0,0), (1,0), (0,1), (2,0), (1,1), (0,2), ..."""
    line,index = getNthPairInfo(N)
    return (line - index, index)

def getNthPairBounded(N,W=aleph0,H=aleph0,useDivmod=False):
    """getNthPairBounded(N, W, H, useDivmod) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    This is a bijection from [0, W*H) onto [0, W) x [0, H).  With
    useDivmod the pairs are laid out row by row (x = N % W); otherwise a
    diagonal line is swept across the rectangle.

    Raises ValueError if a bound is non-positive or N is outside [0, W*H)."""

    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    elif N < 0 or N >= W*H:
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W < H
    if H < W:
        x,y = getNthPairBounded(N,H,W,useDivmod=useDivmod)
        return y,x

    if useDivmod:
        return N%W,N//W
    else:
        # Conceptually we want to slide a diagonal line across a
        # rectangle. This gives more interesting results for large
        # bounds than using divmod.

        # If in lower left, just return as usual
        cornerSize = base(W)
        if N < cornerSize:
            return getNthPair(N)

        # Otherwise if in upper right, subtract from corner
        if H is not aleph0:
            M = W*H - N - 1
            if M < cornerSize:
                x,y = getNthPair(M)
                return (W-1-x,H-1-y)

        # Otherwise, compute line and index from the number of times we
        # wrap; every middle diagonal holds exactly W pairs.
        N = N - cornerSize
        index,offset = N%W,N//W
        # p = (W-1, 1+offset) + (-1,1)*index
        return (W-1-index, 1+offset+index)
def getNthPairBoundedChecked(N,W=aleph0,H=aleph0,useDivmod=False,GNP=getNthPairBounded):
    """Call GNP(N, W, H, useDivmod) and assert the result is within bounds."""
    x,y = GNP(N,W,H,useDivmod)
    assert 0 <= x < W and 0 <= y < H
    return x,y

def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple, where 0 <= x_i < H for every i.

    With useLeftToRight, elements are peeled off one at a time by pairing
    each element with the index of the remaining tuple; otherwise the tuple
    is split in half recursively."""

    if useLeftToRight:
        if W == 0:
            return ()
        elts = [None]*W
        for i in range(W - 1):
            # Pair this element with an index into the remaining
            # (W-1-i)-tuples, of which there are H**(W-1-i).  Keeping the
            # leftover index (rather than discarding it) makes the mapping
            # a bijection.
            elts[i],N = getNthPairBounded(N, H, H**(W-1-i))
        elts[W-1] = N
        return tuple(elts)
    else:
        if W==0:
            return ()
        elif W==1:
            return (N,)
        elif W==2:
            return getNthPairBounded(N, H, H)
        else:
            LW,RW = W//2, W - (W//2)
            L,R = getNthPairBounded(N, H**LW, H**RW)
            return (getNthNTuple(L,LW,H=H,useLeftToRight=useLeftToRight) +
                    getNthNTuple(R,RW,H=H,useLeftToRight=useLeftToRight))
def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Call GNT(N, W, H, useLeftToRight) and assert the result's shape."""
    t = GNT(N,W,H,useLeftToRight)
    assert len(t) == W
    for i in t:
        assert 0 <= i < H
    return t

def getNthTuple(N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize (the size bound is
    inclusive) and for y in x, 0 <= y < maxElement.  Index 0 is the empty
    tuple.

    Raises NotImplementedError when maxElement is finite but maxSize is
    unbounded."""

    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()
    N -= 1
    if maxElement is not aleph0:
        if maxSize is aleph0:
            raise NotImplementedError('Max element size without max size unhandled')
        # bounds[S] is the number of tuples of length S+1.
        bounds = [maxElement**i for i in range(1, maxSize+1)]
        S,M = getNthPairVariableBounds(N, bounds)
    else:
        S,M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
    return getNthNTuple(M, S+1, maxElement, useLeftToRight=useLeftToRight)
def getNthTupleChecked(N, maxSize=aleph0, maxElement=aleph0,
                       useDivmod=False, useLeftToRight=False, GNT=getNthTuple):
    """Call GNT(...) and assert the result's length and element bounds."""
    # maxSize is an inclusive bound on the tuple length.
    t = GNT(N,maxSize,maxElement,useDivmod,useLeftToRight)
    assert len(t) <= maxSize
    for i in t:
        assert i < maxElement
    return t

def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    Raises ValueError if bounds is empty or N is outside [0, sum(bounds))."""

    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Visit the columns in order of increasing bound.  Each "level" is the
    # rectangle of still-active columns between the previous bound and the
    # current one.
    active = list(range(len(bounds)))
    active.sort(key=lambda i: bounds[i])
    prevLevel = 0
    for i,index in enumerate(active):
        level = bounds[index]
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W*H
        if N<levelSize: # Found the level
            idelta,delta = getNthPairBounded(N, W, H)
            return active[i+idelta],prevLevel+delta
        else:
            N -= levelSize
            prevLevel = level
    else:
        # Unreachable given the range check above.
        raise RuntimeError("Unexpected loop completion")

def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Call GNVP(N, bounds) and assert the result is within bounds."""
    x,y = GNVP(N,bounds)
    assert 0 <= x < len(bounds) and 0 <= y < bounds[x]
    return (x,y)

###

def testPairs():
    W = 3
    H = 6
    a = [[' ' for x in range(10)] for y in range(10)]
    b = [[' ' for x in range(10)] for y in range(10)]
    for i in range(min(W*H,40)):
        x,y = getNthPairBounded(i,W,H)
        x2,y2 = getNthPairBounded(i,W,H,useDivmod=True)
        print(i,(x,y),(x2,y2))
        a[y][x] = '%2d'%i
        b[y2][x2] = '%2d'%i

    print('-- a --')
    for ln in a[::-1]:
        if ''.join(ln).strip():
            print(' '.join(ln))
    print('-- b --')
    for ln in b[::-1]:
        if ''.join(ln).strip():
            print(' '.join(ln))

def testPairsVB():
    bounds = [2,2,4,aleph0,5,aleph0]
    a = [[' ' for x in range(15)] for y in range(15)]
    for i in range(min(sum(bounds),40)):
        x,y = getNthPairVariableBounds(i, bounds)
        print(i,(x,y))
        a[y][x] = '%2d'%i

    print('-- a --')
    for ln in a[::-1]:
        if ''.join(ln).strip():
            print(' '.join(ln))

###

# Toggle to use checked versions of enumeration routines.
if False:
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthNTuple = getNthNTupleChecked
    getNthTuple = getNthTupleChecked

if __name__ == '__main__':
    testPairs()

    testPairsVB()