# Authors: Pierre Ablin <pierre.ablin@inria.fr>
#          Alexandre Gramfort <alexandre.gramfort@inria.fr>
#
# License: BSD (3-clause)
import warnings
from itertools import product

import numpy as np
from numpy.testing import assert_allclose

from picard import picard, permute, amari_distance
from picard.densities import Tanh, Exp, Cube, check_density


def _as_density(fun):
    """Return a density object usable for computing the score.

    A string naming a built-in density ('tanh', 'exp' or 'cube') is mapped
    to a default-parameter instance of the matching class; any other value
    (e.g. an already-built density object) is returned unchanged.
    """
    if isinstance(fun, str):
        return {'tanh': Tanh, 'exp': Exp, 'cube': Cube}[fun]()
    return fun


def _gradient(density, Y):
    """Return psi(Y) Y^T / T - I, the gradient matrix checked by the tests.

    ``density`` must expose ``score_and_der``; only the score (first item of
    the returned pair) is used. ``Y`` is the (n_sources, n_samples) output of
    ``picard``.
    """
    n_sources, n_samples = Y.shape
    psiY = density.score_and_der(Y)[0]
    return np.inner(psiY, Y) / float(n_samples) - np.eye(n_sources)


def _check_non_orthogonal_picard(**picard_kwargs):
    """Run Picard (ortho=False) on two toy mixtures and check convergence.

    Laplace sources are separated with a Tanh(alpha=0.5) density and uniform
    sources with the 'cube' density. For each, the final gradient must vanish
    and W K A must equal the identity up to permutation and scale.
    ``picard_kwargs`` are forwarded unchanged to ``picard``.
    """
    N, T = 3, 1000
    rng = np.random.RandomState(42)
    cases = [('tanh', Tanh(params=dict(alpha=0.5))), ('cube', 'cube')]
    for name, fun in cases:
        # Draw order (S, then A) from the shared rng is part of the fixture.
        if name == 'tanh':
            S = rng.laplace(size=(N, T))
        else:
            S = rng.uniform(low=-1, high=1, size=(N, T))
        A = rng.randn(N, N)
        X = np.dot(A, S)
        K, W, Y = picard(X, fun=fun, ortho=False, random_state=0,
                         **picard_kwargs)
        # Get the final gradient norm
        G = _gradient(_as_density(fun), Y)
        assert_allclose(G, np.zeros((N, N)), atol=1e-7,
                        err_msg='fun %s, gradient norm greater than tol'
                        % name)
        assert Y.shape == X.shape
        assert W.shape == A.shape
        assert K.shape == A.shape
        WA = permute(W.dot(K).dot(A))  # Permute and scale
        assert_allclose(WA, np.eye(N), rtol=0, atol=1e-1,
                        err_msg='fun %s, wrong unmixing matrix' % name)


def test_dimension_reduction():
    """Check output shapes when n_components < N, and the warning count.

    With whiten=False and n_components set, exactly two warnings are
    expected from a single-iteration run.
    """
    N, T = 5, 1000
    n_components = 3
    rng = np.random.RandomState(42)
    S = rng.laplace(size=(N, T))
    A = rng.randn(N, N)
    X = np.dot(A, S)
    K, W, Y = picard(X, n_components=n_components, ortho=False,
                     random_state=rng, max_iter=2)
    assert K.shape == (n_components, N)
    assert W.shape == (n_components, n_components)
    # Was ``assert Y.shape, (...)``, which is always true.
    assert Y.shape == (n_components, T)
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the active filters, so the
        # count below does not depend on -W options or earlier tests.
        warnings.simplefilter('always')
        K, W, Y = picard(X, n_components=n_components, ortho=False,
                         whiten=False, max_iter=1)
    assert len(w) == 2


def test_dots():
    """Check Y == W K (X - X_mean) over all option combinations.

    Also checks that orthogonal, whitened outputs have identity covariance.
    """
    N, T = 5, 100
    rng = np.random.RandomState(42)
    S = rng.laplace(size=(N, T))
    A = rng.randn(N, N)
    X = np.dot(A, S)
    n_components = [N, 3]
    tf = [False, True]
    w_inits = [None, 'id']
    for n_component, ortho, whiten, w_init in product(n_components, tf, tf,
                                                      w_inits):
        if w_init == 'id':
            # Without whitening no reduction happens, so W is N x N.
            w_init = np.eye(n_component) if whiten else np.eye(N)
        with warnings.catch_warnings(record=True):
            K, W, Y, X_mean = picard(X, ortho=ortho, whiten=whiten,
                                     return_X_mean=True, w_init=w_init,
                                     n_components=n_component,
                                     random_state=rng, max_iter=2,
                                     verbose=False)
        if not whiten:
            K = np.eye(N)
        if ortho and whiten:
            assert_allclose(Y.dot(Y.T) / T, np.eye(n_component), atol=1e-8)
        Y_prime = np.dot(W, K).dot(X - X_mean[:, None])
        assert_allclose(Y, Y_prime, atol=1e-7)


def test_pre_fastica():
    """Non-orthogonal Picard preceded by 10 FastICA iterations converges."""
    _check_non_orthogonal_picard(fastica_it=10)


def test_picard():
    """Non-orthogonal Picard converges on super/sub-Gaussian mixtures."""
    _check_non_orthogonal_picard()


def test_extended():
    """Extended Picard separates a mix of Laplace and uniform sources."""
    N, T = 4, 2000
    n = N // 2
    rng = np.random.RandomState(42)

    S = np.concatenate((rng.laplace(size=(n, T)),
                        rng.uniform(low=-1, high=1, size=(n, T))),
                       axis=0)
    A = rng.randn(N, N)
    X = np.dot(A, S)
    K, W, Y = picard(X, ortho=False, random_state=0,
                     extended=True)
    assert Y.shape == X.shape
    assert W.shape == A.shape
    assert K.shape == A.shape
    WA = permute(W.dot(K).dot(A))  # Permute and scale
    assert_allclose(WA, np.eye(N), rtol=0, atol=1e-1,
                    err_msg='wrong unmixing matrix')


def test_shift():
    """The data mean is estimated and removed unless centering=False."""
    N, T = 5, 1000
    rng = np.random.RandomState(42)
    S = rng.laplace(size=(N, T))
    A = rng.randn(N, N)
    offset = rng.randn(N)
    X = np.dot(A, S) + offset[:, None]
    _, W, Y, X_mean = picard(X, ortho=False, whiten=False,
                             return_X_mean=True, random_state=rng)
    assert_allclose(offset, X_mean, rtol=0, atol=0.2)
    WA = permute(W.dot(A))
    assert_allclose(WA, np.eye(N), rtol=0, atol=0.2)
    _, W, Y, X_mean = picard(X, ortho=False, whiten=False,
                             centering=False, return_X_mean=True,
                             random_state=rng)
    assert_allclose(X_mean, 0)


def test_picardo():
    """Orthogonal (Picard-O) extended runs converge for each density."""
    N, T = 3, 2000
    rng = np.random.RandomState(4)
    S = rng.laplace(size=(N, T))
    A = rng.randn(N, N)
    X = np.dot(A, S)
    names = ['tanh', 'exp', 'cube']
    for fastica_it in [None, 2]:
        for name in names:
            K, W, Y = picard(X, fun=name, ortho=True, random_state=rng,
                             fastica_it=fastica_it, verbose=True,
                             extended=True)
            # Get the final gradient norm
            G = _gradient(_as_density(name), Y)
            G = (G - G.T) / 2.  # take skew-symmetric part
            assert_allclose(G, np.zeros((N, N)), atol=1e-7,
                            err_msg='fun %s, gradient norm greater than tol'
                            % name)
            assert Y.shape == X.shape
            assert W.shape == A.shape
            assert K.shape == A.shape
            WA = permute(W.dot(K).dot(A))  # Permute and scale
            assert_allclose(WA, np.eye(N), rtol=0, atol=0.1,
                            err_msg='fun %s, wrong unmixing matrix' % name)


def test_bad_custom_density():
    """A density whose derivative is inconsistent must be rejected."""

    class CustomDensity(object):
        def log_lik(self, Y):
            return Y ** 4 / 4

        def score_and_der(self, Y):
            # The "+ 2." makes the derivative wrong on purpose.
            return Y ** 3, 3 * Y ** 2 + 2.

    fun = CustomDensity()
    X = np.random.RandomState(0).randn(2, 10)
    try:
        picard(X, fun=fun, random_state=0)
    except AssertionError:
        pass
    else:
        # ``raise(AssertionError, msg)`` raised a tuple -> TypeError on Py3.
        raise AssertionError('Bad function undetected')


def test_fun():
    """The built-in densities pass ``check_density``."""
    for fun in [Tanh(), Exp(), Cube()]:
        check_density(fun)


def test_no_regression():
    """Mean iteration counts stay below recorded per-setting baselines."""
    n_tests = 10
    baseline = {}
    baseline['lap', True] = 17.
    baseline['lap', False] = 23.
    baseline['gauss', True] = 58.
    baseline['gauss', False] = 60.
    N, T = 10, 1000
    for mode in ['lap', 'gauss']:
        for ortho in [True, False]:
            n_iters = []
            for i in range(n_tests):
                rng = np.random.RandomState(i)
                if mode == 'lap':
                    S = rng.laplace(size=(N, T))
                else:
                    S = rng.randn(N, T)
                A = rng.randn(N, N)
                X = np.dot(A, S)
                _, _, _, n_iter = picard(X, return_n_iter=True,
                                         ortho=ortho, random_state=rng)
                n_iters.append(n_iter)
            n_mean = np.mean(n_iters)
            nb_mean = baseline[mode, ortho]
            # %.1f so the (float) mean is not truncated in the message.
            err_msg = ('mode=%s, ortho=%s. %.1f iterations, expecting <%.1f.'
                       % (mode, ortho, n_mean, nb_mean))
            assert n_mean < nb_mean, err_msg


def test_amari_distance():
    """The Amari distance is ~0 for a scaled, row-permuted inverse of A."""
    p = 3
    rng = np.random.RandomState(0)
    A = rng.randn(p, p)
    W = np.linalg.pinv(A)
    scale = rng.randn(p)
    perm = np.argsort(rng.randn(p))
    W = W[perm]
    W *= scale[:, None]
    assert amari_distance(W, A) < 1e-6