1import numpy as np
2from nilearn.decoding.fista import mfista
3from nilearn.decoding.proximal_operators import _prox_l1
4from nilearn.decoding.objective_functions import (
5    _squared_loss,
6    _logistic,
7    _squared_loss_grad,
8    _logistic_loss_lipschitz_constant,
9    spectral_norm_squared)
10from nilearn.decoding.fista import _check_lipschitz_continuous
11
12
def test_logistic_lipschitz(n_samples=4, n_features=2, random_state=42):
    """Check the Lipschitz constant of the logistic loss.

    For design matrices at several scales, verify that the constant
    returned by ``_logistic_loss_lipschitz_constant`` is a valid
    Lipschitz bound for the logistic loss.  The weight vector carries
    an extra intercept term, hence the ``+ 1`` on the dimension.
    """
    rng = np.random.RandomState(random_state)

    for scale in np.logspace(-3, 3, num=7):
        design = scale * rng.randn(n_samples, n_features)
        target = rng.randn(n_samples)

        lipschitz_constant = _logistic_loss_lipschitz_constant(design)
        _check_lipschitz_continuous(
            lambda w: _logistic(design, target, w),
            design.shape[1] + 1, lipschitz_constant)
24
25
def test_squared_loss_lipschitz(n_samples=4, n_features=2, random_state=42):
    """Check the Lipschitz constant of the squared-loss gradient.

    For design matrices at several scales, verify that
    ``spectral_norm_squared`` yields a valid Lipschitz bound for the
    gradient of the squared loss.
    """
    rng = np.random.RandomState(random_state)

    for scale in np.logspace(-3, 3, num=7):
        design = scale * rng.randn(n_samples, n_features)
        target = rng.randn(n_samples)

        lipschitz_constant = spectral_norm_squared(design)
        _check_lipschitz_continuous(
            lambda w: _squared_loss_grad(design, target, w),
            design.shape[1], lipschitz_constant)
37
38
def test_input_args_and_kwargs():
    """Smoke-test mfista's argument handling and return structure.

    Runs a small lasso-like problem (identity design matrix) over every
    combination of callback return value, verbosity level and
    ``dgap_factor``, and checks that the solver returns a weight vector
    of the expected shape, an objective trace (a list) and an ``init``
    dict exposing the expected warm-start keys.
    """
    rng = np.random.RandomState(42)
    n_voxels = 125
    noise_std = 1e-1

    # Piecewise-constant ground truth with a few isolated spikes.
    signal = np.zeros(n_voxels)
    signal[[0, 2, 13, 4, 25, 32, 80, 89, 91, 93, -1]] = 1
    signal[:6] = 2
    signal[-7:] = 2
    signal[60:75] = 1

    y = signal + noise_std * rng.randn(*signal.shape)
    X = np.eye(n_voxels)
    mask = np.ones((n_voxels,)).astype(bool)

    alpha = .01
    l1_ratio = .2
    l1_weight = alpha * X.shape[0] * l1_ratio

    def smooth_energy(w):
        # Data-fit (smooth) term only, no gradient.
        return _squared_loss(X, y, w, compute_grad=False)

    def smooth_grad(w):
        # Gradient of the data-fit term.
        return _squared_loss(X, y, w, compute_grad=True,
                             compute_energy=False)

    def prox(w, stepsize, *args, **kwargs):
        # l1 proximal operator; mfista may pass extra args/kwargs,
        # which this test deliberately accepts and ignores.
        return _prox_l1(w, stepsize * l1_weight), dict(converged=True)

    def total_energy(w):
        return smooth_energy(w) + l1_weight * np.sum(np.abs(w))

    for cb_retval in [0, 1]:
        for verbose in [0, 1]:
            for dgap_factor in [1., None]:
                best_w, objective, init = mfista(
                    smooth_grad, prox, total_energy, 1., n_voxels,
                    dgap_factor=dgap_factor,
                    callback=lambda _: cb_retval, verbose=verbose,
                    max_iter=100)
                assert best_w.shape == mask.shape
                assert isinstance(objective, list)
                assert isinstance(init, dict)
                for key in ["w", "t", "dgap_tol", "stepsize"]:
                    assert key in init
74