"""Tests for the ``mfista`` solver and the Lipschitz constants of its losses."""
import functools
import itertools

import numpy as np

from nilearn.decoding.fista import _check_lipschitz_continuous, mfista
from nilearn.decoding.objective_functions import (
    _logistic,
    _logistic_loss_lipschitz_constant,
    _squared_loss,
    _squared_loss_grad,
    spectral_norm_squared,
)
from nilearn.decoding.proximal_operators import _prox_l1


def test_logistic_lipschitz(n_samples=4, n_features=2, random_state=42):
    """Check ``_logistic`` against ``_logistic_loss_lipschitz_constant``.

    The random design matrix is rescaled by factors 1e-3 ... 1e3 so the
    Lipschitz bound is exercised on both tiny and very large inputs.
    """
    rng = np.random.RandomState(random_state)

    for scaling in np.logspace(-3, 3, num=7):
        X = rng.randn(n_samples, n_features) * scaling
        y = rng.randn(n_samples)

        lipschitz_constant = _logistic_loss_lipschitz_constant(X)
        # The probed weight vector has one entry more than there are
        # features; presumably the extra coefficient is the intercept.
        # functools.partial binds this iteration's X and y explicitly
        # instead of closing over the loop variables.
        _check_lipschitz_continuous(
            functools.partial(_logistic, X, y),
            n_features + 1,
            lipschitz_constant,
        )


def test_squared_loss_lipschitz(n_samples=4, n_features=2, random_state=42):
    """Check ``_squared_loss_grad`` against ``spectral_norm_squared(X)``.

    As in the logistic test, ``X`` is rescaled by factors 1e-3 ... 1e3.
    """
    rng = np.random.RandomState(random_state)

    for scaling in np.logspace(-3, 3, num=7):
        X = rng.randn(n_samples, n_features) * scaling
        y = rng.randn(n_samples)

        lipschitz_constant = spectral_norm_squared(X)
        _check_lipschitz_continuous(
            functools.partial(_squared_loss_grad, X, y),
            n_features,
            lipschitz_constant,
        )


def test_input_args_and_kwargs():
    """Run ``mfista`` over a grid of optional arguments and check its outputs.

    The problem is an L1-penalised least-squares fit with ``X`` equal to the
    identity, so ``y`` is a noisy copy of the piecewise-constant ``sig``.
    Every combination of callback return value (0/1), verbosity (0/1) and
    ``dgap_factor`` (1.0/None) must return a weight vector of the right
    shape, a list of objective values and an ``init`` dict with the
    expected keys.
    """
    rng = np.random.RandomState(42)
    p = 125
    noise_std = 1e-1
    sig = np.zeros(p)
    sig[[0, 2, 13, 4, 25, 32, 80, 89, 91, 93, -1]] = 1
    sig[:6] = 2
    sig[-7:] = 2
    sig[60:75] = 1
    y = sig + noise_std * rng.randn(*sig.shape)
    X = np.eye(p)
    alpha = .01
    alpha_ = alpha * X.shape[0]
    l1_ratio = .2
    l1_weight = alpha_ * l1_ratio

    def f1(w):
        """Smooth part of the objective (energy only)."""
        return _squared_loss(X, y, w, compute_grad=False)

    def f1_grad(w):
        """Gradient of the smooth part (no energy)."""
        return _squared_loss(X, y, w, compute_grad=True,
                             compute_energy=False)

    def f2_prox(w, stepsize, *args, **kwargs):
        """L1 proximal step scaled by ``stepsize``.

        Extra positional/keyword arguments are accepted and ignored, so the
        test does not depend on exactly what ``mfista`` forwards.
        """
        return _prox_l1(w, stepsize * l1_weight), {"converged": True}

    def total_energy(w):
        """Full objective: smooth loss plus weighted L1 norm."""
        return f1(w) + l1_weight * np.sum(np.abs(w))

    # Same iteration order as the original nested loops
    # (cb_retval outermost, dgap_factor innermost).
    for cb_retval, verbose, dgap_factor in itertools.product(
            [0, 1], [0, 1], [1., None]):
        best_w, objective, init = mfista(
            f1_grad, f2_prox, total_energy, 1., p,
            dgap_factor=dgap_factor,
            # Default-bind the loop variable so the callback is not a
            # late-binding closure.
            callback=lambda _env, retval=cb_retval: retval,
            verbose=verbose,
            max_iter=100)
        assert best_w.shape == (p,)
        assert isinstance(objective, list)
        assert isinstance(init, dict)
        for key in ["w", "t", "dgap_tol", "stepsize"]:
            assert key in init