1# This file is part of the Astrometry.net suite. 2# Licensed under a 3-clause BSD style license - see LICENSE 3from __future__ import print_function 4import sys 5 6#from numpy import array, matrix, linalg 7from numpy import * 8from numpy.random import * 9from numpy.linalg import * 10from matplotlib.pylab import figure, plot, xlabel, ylabel, loglog, clf 11from matplotlib.pylab import semilogy 12#from pylab import * 13 14class Transform(object): 15 scale = None 16 rotation = None 17 incenter = None 18 outcenter = None 19 20 def apply(self, X): 21 #print X 22 dx = X - self.incenter 23 #print dx 24 dx = dx * self.scale 25 #print dx 26 dx = self.rotation * dx 27 #print dx 28 dx = dx + self.outcenter 29 #print dx 30 return dx 31 32 def __str__(self): 33 s = ('<Transform: tin (%f,%f) scale (%f) rot (%f, %f; %f, %f) tout (%f, %f)>' % 34 (self.incenter[0], self.incenter[1], self.scale, 35 self.rotation[0,0], self.rotation[0,1], self.rotation[1,0], self.rotation[1,1], 36 self.outcenter[0], self.outcenter[1])) 37 return s 38 39def procrustes(X, Y): 40 T = Transform() 41 sx = X.shape 42 if sx[0] != 2: 43 print('X must be 2xN') 44 sy = Y.shape 45 if sy[0] != 2: 46 print('Y must be 2xN') 47 N = sx[1] 48 49 mx = X.mean(axis=1).reshape(2,1) 50 my = Y.mean(axis=1).reshape(2,1) 51 #print 'mean(X) is\n', mx 52 #print 'mean(Y) is\n', my 53 T.incenter = mx 54 T.outcenter = my 55 56 #print 'X-mx is\n', X-mx 57 #print '(X-mx)^2 is\n', (X-mx)*(X-mx) 58 varx = sum(sum((X - mx)*(X - mx)), axis=1) 59 vary = sum(sum((Y - my)*(Y - my)), axis=1) 60 #print 'var(X) is', varx 61 #print 'var(Y) is', vary 62 T.scale = sqrt(vary / varx) 63 #print 'scale is', T.scale 64 65 C = zeros((2,2)) 66 for i in [0,1]: 67 for j in [0,1]: 68 C[i,j] = sum((X[i,:] - mx[i]) * (Y[j,:] - my[j])) 69 #print 'cov is\n', C 70 71 U,S,V = svd(C) 72 U = matrix(U) 73 V = matrix(V) 74 75 #print 'U is\n', U 76 #print 'U\' is\n', U.transpose() 77 #print 'V is\n', V 78 R = V * U.transpose() 79 #print 'R is\n', R 80 T.rotation = R 81 return T 82 83 84def test_procrustes_1(): 85 # Create a Transform, apply it to some points, then run procrustes to see if we 86 # recover the Transform exactly. 87 t1 = Transform() 88 t1.scale = 3.0 89 A = 48.0 * pi/180.0 90 t1.rotation = matrix([[sin(A), cos(A)], [-cos(A), sin(A)]]) 91 t1.incenter = array([42, 500]).reshape(2,1) 92 t1.outcenter = array([600, -12]).reshape(2,1) 93 94 N = 4 95 pts = zeros((2,N)) 96 tpts = zeros((2,N)) 97 for i in range(N): 98 pts[0,i] = t1.incenter[0] + ((i % 2) - 0.5) * 200 99 pts[1,i] = t1.incenter[1] + (((i/2) % 2) - 0.5) * 200 100 101 for i in range(N): 102 pt = pts[:,i].reshape(2,1) 103 tpts[:,i] = t1.apply(pt).reshape(1,2) 104 105 t2 = procrustes(pts, tpts) 106 107 print('pts:', pts) 108 print('tpts:', tpts) 109 110 print('t1 is', t1) 111 print('t2 is', t2) 112 113 114def draw_sample(inoise=1, fnoise=0, iqnoise=-1, 115 dimquads=4, quadscale=100, imgsize=1000, 116 Rsteps=10, Asteps=36): 117 118 # Stars that compose the field quad. 119 fquad = zeros((2,dimquads)) 120 fquad[0,0] = imgsize/2 - quadscale/2 121 fquad[1,0] = imgsize/2 122 fquad[0,1] = imgsize/2 + quadscale/2 123 fquad[1,1] = imgsize/2 124 for i in range(2, dimquads): 125 fquad[0,i] = imgsize/2 + randn(1) * quadscale 126 fquad[1,i] = imgsize/2 + randn(1) * quadscale 127 128 # Index quad is field quad plus jitter. 129 iquad = fquad + randn(*fquad.shape) 130 131 # Solve for transformation 132 T = procrustes(iquad, fquad) 133 134 # Put the index quad stars through the transformation 135 itrans = zeros(fquad.shape) 136 for i in range(dimquads): 137 fq = fquad[:,i].reshape(2,1) 138 itrans[:,i] = T.apply(fq).transpose() 139 140 # Field quad center... 141 qc = mean(fquad, axis=1) 142 143 # Sample stars on a R^2, theta grid. 144 #rads = sqrt((array(range(Rsteps))+1) / float(Rsteps)) * imgsize/2 145 N = Rsteps * Asteps 146 rads = sqrt((array(range(Rsteps))+0.5) / float(Rsteps)) * imgsize/2 147 thetas = array(range(Asteps)) / float(Asteps) * 2.0 * pi 148 fstars = zeros((2,N)) 149 for r in range(Rsteps): 150 for a in range(Asteps): 151 fstars[0, r*Asteps + a] = sin(thetas[a]) * rads[r] + qc[0] 152 fstars[1, r*Asteps + a] = cos(thetas[a]) * rads[r] + qc[1] 153 # Put them through the transformation... 154 istars = zeros((2,N)) 155 for i in range(N): 156 fs = fstars[:,i].reshape(2,1) 157 istars[:,i] = T.apply(fs).transpose() 158 159 R = sqrt((fstars[0,:] - qc[0])**2 + (fstars[1,:] - qc[1])**2) 160 E = sqrt(sum((fstars - istars)**2, axis=0)) 161 162 # Fit to a linear model... 163 xfit = R**2 164 yfit = E**2 165 A = zeros((2,N)) 166 A[0,:] = 1 167 A[1,:] = xfit.transpose() 168 (C,resids,rank,s) = lstsq(A.transpose(), yfit) 169 170 return (fquad, iquad, T, itrans, qc, fstars, istars, 171 R, E, C) 172 173if __name__ == '__main__': 174 175 test_procrustes_1() 176 sys.exit(0) 177 178 #N = 1000 179 N = 100 180 C = zeros((2,N)) 181 QD = zeros((N)) 182 for i in range(N): 183 (fquad, iquad, T, itrans, qc, fstars, istars, R, E, c) = draw_sample() 184 C[:,i] = c 185 QD[i] = sqrt(sum((iquad - fquad)**2) / 4.0) 186 187 C0 = C[0,:] 188 C1 = C[1,:] 189 190 figure(1) 191 clf() 192 loglog(C0, C1, 'b.') 193 xlabel('E^2 vs R^2 - Fit coefficient 0') 194 ylabel('E^2 vs R^2 - Fit coefficient 1') 195 196 figure(2) 197 clf() 198 semilogy(QD, C1, 'b.') 199 xlabel('Field-to-Index Quad Mean Distance') 200 ylabel('E^2-vs-R^2 fit linear coefficient') 201 202 #semilogy(QD, C1, 'bo') 203 #xlabel('Quad Distance') 204 #ylabel('C1') 205 206 207 #figure(1) 208 #I=[0,2,1,3,0]; 209 #plot(fquad[0,I], fquad[1,I], 'bo-', itrans[0,I], itrans[1,I], 'ro-') 210 211 #figure(2) 212 #plot(fstars[0,:], fstars[1,:], 'b.', istars[0,:], istars[1,:], 'r.') 213 214 #figure(1) 215 #I=[0,2,1,3,0]; 216 #plot(fquad[0,I], fquad[1,I], 'bo-', 217 # itrans[0,I], itrans[1,I], 'ro-', 218 # fstars[0,:], fstars[1,:], 'b.', 219 # istars[0,:], istars[1,:], 'r.') 220 221 #figure(2) 222 #plot(R, E, 'r.') 223 #xlabel('R') 224 #ylabel('E') 225 226 #figure(3) 227 #plot(R**2, E**2, 'r.') 228 #xlabel('R^2') 229 #ylabel('E^2') 230 231 #print 'Fit coefficients are', C 232 233 #figure(2) 234 #xplot = array(range(101)) / 100.0 * max(xfit) 235 #plot(R**2, E**2, 'r.', 236 # xplot, C[0] + C[1]*xplot, 'b-') 237 #xlabel('R^2') 238 #ylabel('E^2') 239 240 #show() 241 242