1# This file is part of the Astrometry.net suite.
2# Licensed under a 3-clause BSD style license - see LICENSE
3from __future__ import print_function
4import sys
5
6#from numpy import array, matrix, linalg
7from numpy import *
8from numpy.random import *
9from numpy.linalg import *
10from matplotlib.pylab import figure, plot, xlabel, ylabel, loglog, clf
11from matplotlib.pylab import semilogy
12#from pylab import *
13
class Transform(object):
    """A 2-D similarity transform.

    Maps a point X (column vector) to::

        rotation * (scale * (X - incenter)) + outcenter

    All four parameters are public attributes, None until assigned.
    """
    # Scalar scale factor applied after subtracting incenter.
    scale = None
    # 2x2 rotation matrix (numpy.matrix or plain ndarray).
    rotation = None
    # Input-space centre, subtracted first; a (2, 1) column vector where it
    # is set in this module.
    incenter = None
    # Output-space centre, added last; a (2, 1) column vector where it is
    # set in this module.
    outcenter = None

    def apply(self, X):
        """Map points X through the transform.

        X is a 2 x N array of column points (2 x 1 for a single point).
        Returns the 2 x N transformed points; the result is a numpy.matrix
        when ``rotation`` is one, since it comes from ``dot(rotation, ...)``.
        """
        dx = (X - self.incenter) * self.scale
        # dot() is a true matrix product for both numpy.matrix and ndarray
        # rotations; ``*`` would be element-wise for a plain ndarray.
        return dot(self.rotation, dx) + self.outcenter

    def __str__(self):
        # Flatten the centres so (2, 1) columns and (2,) vectors both yield
        # plain scalars for %f (formatting a size-1 array relies on the
        # deprecated array-to-float conversion).
        tin = asarray(self.incenter, dtype=float).ravel()
        tout = asarray(self.outcenter, dtype=float).ravel()
        rot = asarray(self.rotation, dtype=float)
        return ('<Transform: tin (%f,%f) scale (%f) rot (%f, %f; %f, %f) tout (%f, %f)>' %
                (tin[0], tin[1], self.scale,
                 rot[0, 0], rot[0, 1], rot[1, 0], rot[1, 1],
                 tout[0], tout[1]))
38
def procrustes(X, Y):
    """Fit a 2-D similarity transform mapping point set X onto point set Y.

    The rotation is the orthogonal-Procrustes (SVD) solution for the
    centred points. The scale is the ratio of the RMS spreads of Y and X
    about their centroids, which is exact for noiseless data.

    Parameters:
        X, Y: 2 x N arrays of corresponding points; column k of X is
            matched with column k of Y.

    Returns:
        A Transform T with T.apply(X) ~= Y. T.incenter and T.outcenter are
        the (2, 1) centroids of X and Y. T.scale is a scalar. T.rotation is
        a 2x2 proper rotation stored as a numpy.matrix.

    Raises:
        ValueError: if X or Y is not 2 x N, if their point counts differ, or
            if all points of X coincide (the scale is then undefined).
    """
    X = asarray(X, dtype=float)
    Y = asarray(Y, dtype=float)
    if X.ndim != 2 or X.shape[0] != 2:
        raise ValueError('X must be 2xN')
    if Y.ndim != 2 or Y.shape[0] != 2:
        raise ValueError('Y must be 2xN')
    if X.shape[1] != Y.shape[1]:
        raise ValueError('X and Y must contain the same number of points')

    T = Transform()
    mx = X.mean(axis=1).reshape(2, 1)
    my = Y.mean(axis=1).reshape(2, 1)
    T.incenter = mx
    T.outcenter = my

    dX = X - mx
    dY = Y - my
    # Total squared deviation of each point set about its centroid.
    varx = (dX ** 2).sum()
    vary = (dY ** 2).sum()
    if varx == 0:
        raise ValueError('X points are all identical; scale is undefined')
    T.scale = sqrt(vary / varx)

    # Cross-covariance: C[i, j] = sum_k dX[i, k] * dY[j, k].
    C = dot(dX, dY.T)

    # With C = U S V^T, the rotation maximising trace(R C) (i.e. minimising
    # |dY - s R dX|^2) is R = V U^T. numpy's svd returns V^T, not V, hence
    # the transpose.
    U, S, Vt = svd(C)
    V = Vt.T
    # Flip the weakest axis if needed so R is a rotation, not a reflection.
    d = -1.0 if det(dot(V, U.T)) < 0 else 1.0
    R = dot(dot(V, diag([1.0, d])), U.T)
    # Kept as numpy.matrix for callers that multiply it with ``*``.
    T.rotation = matrix(R)
    return T
82
83
def test_procrustes_1():
    """Check that procrustes() recovers a known transform from exact data.

    A Transform with known scale, rotation and centres maps the corners of
    a rectangle centred on its input centre. procrustes() is then fit to the
    (input, output) pairs, and every recovered parameter is compared with
    the original.

    Raises:
        AssertionError: if any recovered parameter differs from the truth.
    """
    t1 = Transform()
    t1.scale = 3.0
    A = 48.0 * pi / 180.0
    # Determinant is sin^2 + cos^2 = 1, so this is a proper rotation.
    t1.rotation = matrix([[sin(A), cos(A)], [-cos(A), sin(A)]])
    t1.incenter = array([42, 500]).reshape(2, 1)
    t1.outcenter = array([600, -12]).reshape(2, 1)

    # Corners of a 200 x 100 rectangle centred on t1.incenter. Unequal sides
    # give the cross-covariance two distinct singular values, so the
    # general (non-degenerate) SVD case is exercised.
    N = 4
    half_w, half_h = 100.0, 50.0
    pts = zeros((2, N))
    for i in range(N):
        # i // 2 (not i / 2) keeps the corner pattern on Python 3 as well.
        pts[0, i] = t1.incenter[0, 0] + (2 * (i % 2) - 1) * half_w
        pts[1, i] = t1.incenter[1, 0] + (2 * ((i // 2) % 2) - 1) * half_h

    # Transform.apply works column-wise on a 2 x N array.
    tpts = asarray(t1.apply(pts))

    t2 = procrustes(pts, tpts)

    print('pts:', pts)
    print('tpts:', tpts)

    print('t1 is', t1)
    print('t2 is', t2)

    assert allclose(t2.scale, t1.scale), 'scale not recovered'
    assert allclose(t2.rotation, t1.rotation), 'rotation not recovered'
    assert allclose(t2.incenter, t1.incenter), 'input centre not recovered'
    assert allclose(t2.outcenter, t1.outcenter), 'output centre not recovered'
112
113
def draw_sample(inoise=1, fnoise=0, iqnoise=-1,
                dimquads=4, quadscale=100, imgsize=1000,
                Rsteps=10, Asteps=36):
    """Simulate how jitter in a quad match distorts the fitted transform.

    A "field" quad is built around the image centre. Its first two stars
    lie on a horizontal line ``quadscale`` apart; the rest are Gaussian
    scattered with standard deviation ``quadscale``. The "index" quad is the
    field quad plus standard-normal jitter, and procrustes() fits the
    transform T mapping index quad -> field quad. The underlying relation is
    the identity, so the displacement T applies to a grid of stars measures
    the error caused by the jitter. E^2 is then fit as a linear function of
    R^2, the squared distance from the field-quad centre.

    Parameters:
        inoise, fnoise, iqnoise: accepted but currently unused.
        dimquads: number of stars per quad (at least 2).
        quadscale: spacing of the first two stars / scatter of the others.
        imgsize: image size; the star grid reaches radius imgsize/2.
        Rsteps, Asteps: number of radial and angular grid steps.

    Returns:
        (fquad, iquad, T, itrans, qc, fstars, istars, R, E, C):
        - fquad, iquad: 2 x dimquads field and index quads.
        - T: the fitted Transform.
        - itrans: iquad mapped through T (close to fquad for a good fit).
        - qc: field-quad centre, shape (2,).
        - fstars, istars: 2 x (Rsteps*Asteps) grid stars before and after T.
        - R, E: per-star radius from qc and displacement under T.
        - C: least-squares coefficients (c0, c1) of E^2 ~ c0 + c1 * R^2.
    """
    half = imgsize / 2.0

    # Stars that compose the field quad.
    fquad = zeros((2, dimquads))
    fquad[:, 0] = (half - quadscale / 2.0, half)
    fquad[:, 1] = (half + quadscale / 2.0, half)
    # Drawing an (n, 2) block and transposing consumes the random stream in
    # the same x0, y0, x1, y1, ... order as drawing one value at a time.
    fquad[:, 2:] = half + randn(dimquads - 2, 2).T * quadscale

    # Index quad is field quad plus jitter.
    iquad = fquad + randn(*fquad.shape)

    # Solve for the transformation taking the index quad onto the field quad.
    T = procrustes(iquad, fquad)

    # Put the index quad stars through the transformation.
    itrans = asarray(T.apply(iquad))

    # Field quad centre.
    qc = fquad.mean(axis=1)

    # Sample stars on a grid uniform in R^2 and theta about the quad centre;
    # star r*Asteps + a has radius rads[r] and angle thetas[a].
    N = Rsteps * Asteps
    rads = sqrt((arange(Rsteps) + 0.5) / float(Rsteps)) * imgsize / 2.0
    thetas = arange(Asteps) / float(Asteps) * 2.0 * pi
    rr = repeat(rads, Asteps)
    tt = tile(thetas, Rsteps)
    fstars = vstack((sin(tt) * rr + qc[0],
                     cos(tt) * rr + qc[1]))
    # Put them through the transformation.
    istars = asarray(T.apply(fstars))

    R = sqrt((fstars[0] - qc[0]) ** 2 + (fstars[1] - qc[1]) ** 2)
    E = sqrt(((fstars - istars) ** 2).sum(axis=0))

    # Least-squares fit of the linear model E^2 = c0 + c1 * R^2.
    # rcond=-1 keeps the classic machine-precision cutoff on every numpy.
    A = vstack((ones(N), R ** 2)).T
    C = lstsq(A, E ** 2, rcond=-1)[0]

    return (fquad, iquad, T, itrans, qc, fstars, istars,
            R, E, C)
172
if __name__ == '__main__':

    # Quick self-check: fit a known transform and verify it is recovered.
    test_procrustes_1()

    # The Monte-Carlo experiment below runs only when requested with
    # ``--experiment``; by default the script stops after the self-check.
    if '--experiment' not in sys.argv[1:]:
        sys.exit(0)

    from matplotlib.pylab import show

    # Draw many random quad matches and collect the E^2-vs-R^2 fit
    # coefficients of each.
    #N = 1000
    N = 100
    C = zeros((2,N))
    QD = zeros((N))
    for i in range(N):
        (fquad, iquad, T, itrans, qc, fstars, istars, R, E, c) = draw_sample()
        C[:,i] = c
        # RMS per-star distance between the index and field quads.
        QD[i] = sqrt(((iquad - fquad)**2).sum() / float(iquad.shape[1]))

    C0 = C[0,:]
    C1 = C[1,:]

    figure(1)
    clf()
    loglog(C0, C1, 'b.')
    xlabel('E^2 vs R^2 - Fit coefficient 0')
    ylabel('E^2 vs R^2 - Fit coefficient 1')

    figure(2)
    clf()
    semilogy(QD, C1, 'b.')
    xlabel('Field-to-Index Quad Mean Distance')
    ylabel('E^2-vs-R^2 fit linear coefficient')

    # Scratch plots of the last draw_sample() result, kept for reference.

    #semilogy(QD, C1, 'bo')
    #xlabel('Quad Distance')
    #ylabel('C1')


    #figure(1)
    #I=[0,2,1,3,0];
    #plot(fquad[0,I], fquad[1,I], 'bo-', itrans[0,I], itrans[1,I], 'ro-')

    #figure(2)
    #plot(fstars[0,:], fstars[1,:], 'b.', istars[0,:], istars[1,:], 'r.')

    #figure(1)
    #I=[0,2,1,3,0];
    #plot(fquad[0,I], fquad[1,I], 'bo-',
    #     itrans[0,I], itrans[1,I], 'ro-',
    #     fstars[0,:], fstars[1,:], 'b.',
    #     istars[0,:], istars[1,:], 'r.')

    #figure(2)
    #plot(R, E, 'r.')
    #xlabel('R')
    #ylabel('E')

    #figure(3)
    #plot(R**2, E**2, 'r.')
    #xlabel('R^2')
    #ylabel('E^2')

    #print 'Fit coefficients are', C

    #figure(2)
    #xplot = array(range(101)) / 100.0 * max(xfit)
    #plot(R**2, E**2, 'r.',
    #     xplot, C[0] + C[1]*xplot, 'b-')
    #xlabel('R^2')
    #ylabel('E^2')

    # Display the experiment figures (needed outside interactive sessions).
    show()
241
242