# Authors: Pierre Ablin <pierre.ablin@inria.fr>
#          Alexandre Gramfort <alexandre.gramfort@inria.fr>
#
# License: BSD (3-clause)
import warnings
from itertools import product

import numpy as np
from numpy.testing import assert_allclose

from picard import picard, permute, amari_distance
from picard.densities import Tanh, Exp, Cube, check_density


def test_dimension_reduction():
    N, T = 5, 1000
    n_components = 3
    rng = np.random.RandomState(42)
    S = rng.laplace(size=(N, T))
    A = rng.randn(N, N)
    X = np.dot(A, S)
    K, W, Y = picard(X, n_components=n_components, ortho=False,
                     random_state=rng, max_iter=2)
    assert K.shape == (n_components, N)
    assert W.shape == (n_components, n_components)
    assert Y.shape == (n_components, T)
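    # Two warnings are expected here: presumably one because n_components
    # cannot be honored without whitening, and one because max_iter=1 stops
    # the solver before convergence.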
    with warnings.catch_warnings(record=True) as w:
        K, W, Y = picard(X, n_components=n_components, ortho=False,
                         whiten=False, max_iter=1)
        assert len(w) == 2


def test_dots():
    N, T = 5, 100
    rng = np.random.RandomState(42)
    S = rng.laplace(size=(N, T))
    A = rng.randn(N, N)
    X = np.dot(A, S)
    n_components = [N, 3]
    tf = [False, True]
    w_inits = [None, 'id']
    for n_component, ortho, whiten, w_init in product(n_components, tf, tf,
                                                      w_inits):
        if w_init == 'id':
            if whiten:
                w_init = np.eye(n_component)
            else:
                w_init = np.eye(N)
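        # Some configurations in the grid (e.g. whiten=False combined with
        # fewer components) are expected to warn, so warnings are silenced.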
        with warnings.catch_warnings(record=True):
            K, W, Y, X_mean = picard(X, ortho=ortho, whiten=whiten,
                                     return_X_mean=True, w_init=w_init,
                                     n_components=n_component,
                                     random_state=rng, max_iter=2,
                                     verbose=False)
        if not whiten:
            K = np.eye(N)
        if ortho and whiten:
            assert_allclose(Y.dot(Y.T) / T, np.eye(n_component), atol=1e-8)
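        # The returned sources must match the whitening and unmixing
        # matrices applied to the centered data: Y = W K (X - X_mean).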
        Y_prime = np.dot(W, K).dot(X - X_mean[:, None])
        assert_allclose(Y, Y_prime, atol=1e-7)


def test_pre_fastica():
    N, T = 3, 1000
    rng = np.random.RandomState(42)
    names = ['tanh', 'cube']
    for j, fun in enumerate([Tanh(params=dict(alpha=0.5)), 'cube']):
        if j == 0:
            S = rng.laplace(size=(N, T))
        else:
            S = rng.uniform(low=-1, high=1, size=(N, T))
        A = rng.randn(N, N)
        X = np.dot(A, S)
        K, W, Y = picard(X, fun=fun, ortho=False, random_state=0,
                         fastica_it=10)
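        # fastica_it=10 presumably runs a few FastICA iterations as
        # preconditioning before the Picard solver; the fixed point reached
        # should satisfy the same stationarity conditions as without it.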
        if fun == 'tanh':
            fun = Tanh()
        elif fun == 'exp':
            fun = Exp()
        elif fun == 'cube':
            fun = Cube()
        # Get the final gradient norm
        psiY = fun.score_and_der(Y)[0]
        G = np.inner(psiY, Y) / float(T) - np.eye(N)
        err_msg = 'fun %s, gradient norm greater than tol' % names[j]
        assert_allclose(G, np.zeros((N, N)), atol=1e-7,
                        err_msg=err_msg)
        assert Y.shape == X.shape
        assert W.shape == A.shape
        assert K.shape == A.shape
        WA = W.dot(K).dot(A)
        WA = permute(WA)  # Permute and scale
        err_msg = 'fun %s, wrong unmixing matrix' % names[j]
        assert_allclose(WA, np.eye(N), rtol=0, atol=1e-1,
                        err_msg=err_msg)


def test_picard():
    N, T = 3, 1000
    rng = np.random.RandomState(42)
    names = ['tanh', 'cube']
    for j, fun in enumerate([Tanh(params=dict(alpha=0.5)), 'cube']):
        if j == 0:
            S = rng.laplace(size=(N, T))
        else:
            S = rng.uniform(low=-1, high=1, size=(N, T))
        A = rng.randn(N, N)
        X = np.dot(A, S)
        K, W, Y = picard(X, fun=fun, ortho=False, random_state=0)
        if fun == 'tanh':
            fun = Tanh()
        elif fun == 'exp':
            fun = Exp()
        elif fun == 'cube':
            fun = Cube()
        # Get the final gradient norm
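        # At a stationary point of maximum-likelihood ICA,
        # E[psi(Y) Y^T] = I, so G below should vanish at convergence.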
        psiY = fun.score_and_der(Y)[0]
        G = np.inner(psiY, Y) / float(T) - np.eye(N)
        err_msg = 'fun %s, gradient norm greater than tol' % names[j]
        assert_allclose(G, np.zeros((N, N)), atol=1e-7,
                        err_msg=err_msg)
        assert Y.shape == X.shape
        assert W.shape == A.shape
        assert K.shape == A.shape
        WA = W.dot(K).dot(A)
        WA = permute(WA)  # Permute and scale
        err_msg = 'fun %s, wrong unmixing matrix' % names[j]
        assert_allclose(WA, np.eye(N), rtol=0, atol=1e-1,
                        err_msg=err_msg)


def test_extended():
    N, T = 4, 2000
    n = N // 2
    rng = np.random.RandomState(42)

    S = np.concatenate((rng.laplace(size=(n, T)),
                        rng.uniform(low=-1, high=1, size=(n, T))),
                       axis=0)
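    # The sources mix super-Gaussian (Laplace) and sub-Gaussian (uniform)
    # signals; extended=True lets the algorithm adapt to both kinds.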
    A = rng.randn(N, N)
    X = np.dot(A, S)
    K, W, Y = picard(X, ortho=False, random_state=0,
                     extended=True)
    assert Y.shape == X.shape
    assert W.shape == A.shape
    assert K.shape == A.shape
    WA = W.dot(K).dot(A)
    WA = permute(WA)  # Permute and scale
    err_msg = 'wrong unmixing matrix'
    assert_allclose(WA, np.eye(N), rtol=0, atol=1e-1,
                    err_msg=err_msg)


def test_shift():
    N, T = 5, 1000
    rng = np.random.RandomState(42)
    S = rng.laplace(size=(N, T))
    A = rng.randn(N, N)
    offset = rng.randn(N)
    X = np.dot(A, S) + offset[:, None]
    _, W, Y, X_mean = picard(X, ortho=False, whiten=False,
                             return_X_mean=True, random_state=rng)
    assert_allclose(offset, X_mean, rtol=0, atol=0.2)
    WA = W.dot(A)
    WA = permute(WA)
    assert_allclose(WA, np.eye(N), rtol=0, atol=0.2)
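    # With centering disabled, the returned mean should be exactly zero.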
    _, W, Y, X_mean = picard(X, ortho=False, whiten=False,
                             centering=False, return_X_mean=True,
                             random_state=rng)
    assert_allclose(X_mean, 0)


def test_picardo():
    N, T = 3, 2000
    rng = np.random.RandomState(4)
    S = rng.laplace(size=(N, T))
    A = rng.randn(N, N)
    X = np.dot(A, S)
    names = ['tanh', 'exp', 'cube']
    for fastica_it in [None, 2]:
        for name in names:
            K, W, Y = picard(X, fun=name, ortho=True, random_state=rng,
                             fastica_it=fastica_it, verbose=True,
                             extended=True)
            if name == 'tanh':
                fun = Tanh()
            elif name == 'exp':
                fun = Exp()
            elif name == 'cube':
                fun = Cube()
            # Get the final gradient norm
            psiY = fun.score_and_der(Y)[0]
            G = np.inner(psiY, Y) / float(T) - np.eye(N)
            G = (G - G.T) / 2.  # take skew-symmetric part
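            # Under the orthogonal constraint, only the skew-symmetric
            # part of the relative gradient has to vanish at convergence.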
            err_msg = 'fun %s, gradient norm greater than tol' % name
            assert_allclose(G, np.zeros((N, N)), atol=1e-7,
                            err_msg=err_msg)
            assert Y.shape == X.shape
            assert W.shape == A.shape
            assert K.shape == A.shape
            WA = W.dot(K).dot(A)
            WA = permute(WA)  # Permute and scale
            err_msg = 'fun %s, wrong unmixing matrix' % name
            assert_allclose(WA, np.eye(N), rtol=0, atol=0.1,
                            err_msg=err_msg)


def test_bad_custom_density():

    class CustomDensity(object):
        def log_lik(self, Y):
            return Y ** 4 / 4

        def score_and_der(self, Y):
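            # Deliberately inconsistent: the derivative of the score
            # Y ** 3 is 3 * Y ** 2, not 3 * Y ** 2 + 2.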
            return Y ** 3, 3 * Y ** 2 + 2.

    fun = CustomDensity()
    X = np.random.randn(2, 10)
    try:
        picard(X, fun=fun, random_state=0)
    except AssertionError:
        pass
    else:
        raise AssertionError('Bad function undetected')


def test_fun():
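    # check_density is expected to verify numerically that each density's
    # score and derivative are consistent with its log-likelihood.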
    for fun in [Tanh(), Exp(), Cube()]:
        check_density(fun)


def test_no_regression():
    n_tests = 10
    baseline = {}
    baseline['lap', True] = 17.
    baseline['lap', False] = 23.
    baseline['gauss', True] = 58.
    baseline['gauss', False] = 60.
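    # The values above are recorded mean iteration counts; a regression in
    # the solver should push the observed averages past these baselines.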
    N, T = 10, 1000
    for mode in ['lap', 'gauss']:
        for ortho in [True, False]:
            n_iters = []
            for i in range(n_tests):
                rng = np.random.RandomState(i)
                if mode == 'lap':
                    S = rng.laplace(size=(N, T))
                else:
                    S = rng.randn(N, T)
                A = rng.randn(N, N)
                X = np.dot(A, S)
                _, _, _, n_iter = picard(X, return_n_iter=True,
                                         ortho=ortho, random_state=rng)
                n_iters.append(n_iter)
            n_mean = np.mean(n_iters)
            nb_mean = baseline[mode, ortho]
            err_msg = ('mode=%s, ortho=%s. %.1f iterations on average, '
                       'expecting < %d.')
            assert n_mean < nb_mean, err_msg % (mode, ortho, n_mean, nb_mean)


def test_amari_distance():
    p = 3
    rng = np.random.RandomState(0)
    A = rng.randn(p, p)
    W = np.linalg.pinv(A)
    scale = rng.randn(p)
    perm = np.argsort(rng.randn(p))
    W = W[perm]
    W *= scale[:, None]
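    # The Amari distance is invariant to permutation and scaling of the
    # rows of W, so the permuted and rescaled inverse should score ~0.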
    assert amari_distance(W, A) < 1e-6