1"""
2Python implementation of the fast ICA algorithms.
3
4Reference: Tables 8.3 and 8.4 page 196 in the book:
5Independent Component Analysis, by  Hyvarinen et al.
6"""
7
8# Authors: Pierre Lafaye de Micheaux, Stefan van der Walt, Gael Varoquaux,
9#          Bertrand Thirion, Alexandre Gramfort, Denis A. Engemann
10# License: BSD 3 clause
11
12import warnings
13
14import numpy as np
15from scipy import linalg
16
17from ..base import BaseEstimator, TransformerMixin
18from ..exceptions import ConvergenceWarning
19
20from ..utils import check_array, as_float_array, check_random_state
21from ..utils.validation import check_is_fitted
22from ..utils.validation import FLOAT_DTYPES
23
24__all__ = ["fastica", "FastICA"]
25
26
27def _gs_decorrelation(w, W, j):
28    """
29    Orthonormalize w wrt the first j rows of W.
30
31    Parameters
32    ----------
33    w : ndarray of shape (n,)
34        Array to be orthogonalized
35
36    W : ndarray of shape (p, n)
37        Null space definition
38
39    j : int < p
40        The no of (from the first) rows of Null space W wrt which w is
41        orthogonalized.
42
43    Notes
44    -----
45    Assumes that W is orthogonal
46    w changed in place
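
    Examples
    --------
    A small illustrative check (W orthonormal, as assumed above):

    >>> import numpy as np
    >>> W = np.eye(3)
    >>> w = np.array([1.0, 1.0, 1.0])
    >>> _gs_decorrelation(w, W, 2)
    array([0., 0., 1.])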
47    """
48    w -= np.linalg.multi_dot([w, W[:j].T, W[:j]])
49    return w
50
51
52def _sym_decorrelation(W):
53    """Symmetric decorrelation
54    i.e. W <- (W * W.T) ^{-1/2} * W
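
    Examples
    --------
    After decorrelation the rows of W are orthonormal (illustrative check):

    >>> import numpy as np
    >>> rng = np.random.RandomState(0)
    >>> W = _sym_decorrelation(rng.randn(3, 3))
    >>> np.allclose(np.dot(W, W.T), np.eye(3))
    True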
55    """
56    s, u = linalg.eigh(np.dot(W, W.T))
57    # u (resp. s) contains the eigenvectors (resp. square roots of
58    # the eigenvalues) of W * W.T
59    return np.linalg.multi_dot([u * (1.0 / np.sqrt(s)), u.T, W])
60
61
62def _ica_def(X, tol, g, fun_args, max_iter, w_init):
63    """Deflationary FastICA using fun approx to neg-entropy function
64
65    Used internally by FastICA.
66    """
67
68    n_components = w_init.shape[0]
69    W = np.zeros((n_components, n_components), dtype=X.dtype)
70    n_iter = []
71
72    # j is the index of the extracted component
73    for j in range(n_components):
74        w = w_init[j, :].copy()
75        w /= np.sqrt((w ** 2).sum())
76
77        for i in range(max_iter):
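            # One-unit fixed-point update (Hyvarinen & Oja):
            #   w1 = E[x * g(w.T x)] - E[g'(w.T x)] * w
            # followed by Gram-Schmidt decorrelation against the rows of W
            # that have already been extracted.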
            gwtx, g_wtx = g(np.dot(w.T, X), fun_args)

            w1 = (X * gwtx).mean(axis=1) - g_wtx.mean() * w

            _gs_decorrelation(w1, W, j)

            w1 /= np.sqrt((w1 ** 2).sum())

            lim = np.abs(np.abs((w1 * w).sum()) - 1)
            w = w1
            if lim < tol:
                break

        n_iter.append(i + 1)
        W[j, :] = w

    return W, max(n_iter)


def _ica_par(X, tol, g, fun_args, max_iter, w_init):
    """Parallel FastICA: estimate all components at once.

    Used internally by FastICA -- main loop.
    """
    W = _sym_decorrelation(w_init)
    del w_init
    p_ = float(X.shape[1])
    for ii in range(max_iter):
        gwtx, g_wtx = g(np.dot(W, X), fun_args)
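        # Fixed-point update for all units at once:
        #   W1 = E[g(W x) x.T] - diag(E[g'(W x)]) W
        # Symmetric decorrelation then keeps the rows of W1 orthonormal,
        # preventing several units from converging to the same source.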
        W1 = _sym_decorrelation(np.dot(gwtx, X.T) / p_ - g_wtx[:, np.newaxis] * W)
        del gwtx, g_wtx
        # builtin max, abs are faster than numpy counterparts.
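        # Convergence test: each updated row should be colinear (up to sign)
        # with its previous value, i.e. |<w1_i, w_i>| should approach 1.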
        lim = max(abs(abs(np.diag(np.dot(W1, W.T))) - 1))
        W = W1
        if lim < tol:
            break
    else:
        warnings.warn(
            "FastICA did not converge. Consider increasing "
            "tolerance or the maximum number of iterations.",
            ConvergenceWarning,
        )

    return W, ii + 1


# Some standard non-linear functions.
# XXX: these should be optimized, as they can be a bottleneck.
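# Each function maps x (one row per unit) to the elementwise non-linearity
# g(x) and the mean of its derivative g'(x) over the last axis, which is
# the pair consumed by _ica_def and _ica_par above.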
def _logcosh(x, fun_args=None):
    alpha = fun_args.get("alpha", 1.0)

    x *= alpha
    gx = np.tanh(x, x)  # apply the tanh inplace
    g_x = np.empty(x.shape[0])
    # XXX compute in chunks to avoid extra allocation
    for i, gx_i in enumerate(gx):  # please don't vectorize.
        g_x[i] = (alpha * (1 - gx_i ** 2)).mean()
    return gx, g_x


def _exp(x, fun_args):
    exp = np.exp(-(x ** 2) / 2)
    gx = x * exp
    g_x = (1 - x ** 2) * exp
    return gx, g_x.mean(axis=-1)


def _cube(x, fun_args):
    return x ** 3, (3 * x ** 2).mean(axis=-1)


def fastica(
    X,
    n_components=None,
    *,
    algorithm="parallel",
    whiten=True,
    fun="logcosh",
    fun_args=None,
    max_iter=200,
    tol=1e-04,
    w_init=None,
    random_state=None,
    return_X_mean=False,
    compute_sources=True,
    return_n_iter=False,
):
166    """Perform Fast Independent Component Analysis.
167
168    The implementation is based on [1]_.
169
170    Read more in the :ref:`User Guide <ICA>`.
171
172    Parameters
173    ----------
174    X : array-like of shape (n_samples, n_features)
175        Training vector, where `n_samples` is the number of samples and
176        `n_features` is the number of features.
177
178    n_components : int, default=None
179        Number of components to extract. If None no dimension reduction
180        is performed.
181
182    algorithm : {'parallel', 'deflation'}, default='parallel'
183        Apply a parallel or deflational FASTICA algorithm.
184
185    whiten : bool, default=True
186        If True perform an initial whitening of the data.
187        If False, the data is assumed to have already been
188        preprocessed: it should be centered, normed and white.
189        Otherwise you will get incorrect results.
190        In this case the parameter n_components will be ignored.
191
192    fun : {'logcosh', 'exp', 'cube'} or callable, default='logcosh'
193        The functional form of the G function used in the
194        approximation to neg-entropy. Could be either 'logcosh', 'exp',
195        or 'cube'.
196        You can also provide your own function. It should return a tuple
197        containing the value of the function, and of its derivative, in the
198        point. The derivative should be averaged along its last dimension.
        Example::

            def my_g(x):
                return x ** 3, np.mean(3 * x ** 2, axis=-1)

    fun_args : dict, default=None
        Arguments to send to the functional form.
        If empty or None and if fun='logcosh', fun_args will take value
        {'alpha' : 1.0}.

    max_iter : int, default=200
        Maximum number of iterations to perform.

    tol : float, default=1e-04
        A positive scalar giving the tolerance at which the
        un-mixing matrix is considered to have converged.

    w_init : ndarray of shape (n_components, n_components), default=None
        Initial un-mixing array. If None (default), an array of values
        drawn from a normal distribution is used.

    random_state : int, RandomState instance or None, default=None
        Used to initialize ``w_init`` when not specified, with a
        normal distribution. Pass an int for reproducible results
        across multiple function calls.
        See :term:`Glossary <random_state>`.

    return_X_mean : bool, default=False
        If True, X_mean is returned too.

    compute_sources : bool, default=True
        If False, sources are not computed, but only the rotation matrix.
        This can save memory when working with big data. Defaults to True.

    return_n_iter : bool, default=False
        Whether or not to return the number of iterations.

    Returns
    -------
    K : ndarray of shape (n_components, n_features) or None
        If whiten is True, K is the pre-whitening matrix that projects data
        onto the first n_components principal components. If whiten is False,
        K is None.

    W : ndarray of shape (n_components, n_components)
        The square matrix that unmixes the data after whitening.
        The mixing matrix is the pseudo-inverse of the matrix ``W K``
        if K is not None, else it is the inverse of W.

    S : ndarray of shape (n_samples, n_components) or None
        Estimated source matrix.

    X_mean : ndarray of shape (n_features,)
        The mean over features. Returned only if return_X_mean is True.

    n_iter : int
        If the algorithm is "deflation", n_iter is the
        maximum number of iterations run across all components. Else
        it is just the number of iterations taken to converge. This is
        returned only when return_n_iter is set to `True`.

    Notes
    -----
    The data matrix X is considered to be a linear combination of
    non-Gaussian (independent) components, i.e. X = AS where columns of S
    contain the independent components and A is a linear mixing
    matrix. In short, ICA attempts to 'un-mix' the data by estimating an
    un-mixing matrix W where ``S = W K X``.
    While FastICA was proposed to estimate as many sources
    as features, it is possible to estimate fewer by setting
    n_components < n_features. In this case K is not a square matrix
    and the estimated A is the pseudo-inverse of ``W K``.

    This implementation was originally made for data of shape
    (n_features, n_samples). Now the input is transposed
    before the algorithm is applied. This makes it slightly
    faster for Fortran-ordered input.

    References
    ----------
    .. [1] A. Hyvarinen and E. Oja, "Independent Component Analysis:
           Algorithms and Applications", Neural Networks, 13(4-5), 2000,
           pp. 411-430.
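
    Examples
    --------
    A minimal illustrative run on the digits data (the output shapes
    follow the Returns section above):

    >>> from sklearn.datasets import load_digits
    >>> from sklearn.decomposition import fastica
    >>> X, _ = load_digits(return_X_y=True)
    >>> K, W, S = fastica(X, n_components=7, random_state=0)
    >>> K.shape, W.shape, S.shape
    ((7, 64), (7, 7), (1797, 7))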
282    """
283
284    est = FastICA(
285        n_components=n_components,
286        algorithm=algorithm,
287        whiten=whiten,
288        fun=fun,
289        fun_args=fun_args,
290        max_iter=max_iter,
291        tol=tol,
292        w_init=w_init,
293        random_state=random_state,
294    )
295    sources = est._fit(X, compute_sources=compute_sources)
296
297    if whiten:
298        if return_X_mean:
299            if return_n_iter:
300                return (est.whitening_, est._unmixing, sources, est.mean_, est.n_iter_)
301            else:
302                return est.whitening_, est._unmixing, sources, est.mean_
303        else:
304            if return_n_iter:
305                return est.whitening_, est._unmixing, sources, est.n_iter_
306            else:
307                return est.whitening_, est._unmixing, sources
308
309    else:
310        if return_X_mean:
311            if return_n_iter:
312                return None, est._unmixing, sources, None, est.n_iter_
313            else:
314                return None, est._unmixing, sources, None
315        else:
316            if return_n_iter:
317                return None, est._unmixing, sources, est.n_iter_
318            else:
319                return None, est._unmixing, sources


class FastICA(TransformerMixin, BaseEstimator):
    """FastICA: a fast algorithm for Independent Component Analysis.

    The implementation is based on [1]_.

    Read more in the :ref:`User Guide <ICA>`.

    Parameters
    ----------
    n_components : int, default=None
        Number of components to use. If None is passed, all are used.

    algorithm : {'parallel', 'deflation'}, default='parallel'
        Apply a parallel or deflational FastICA algorithm.

    whiten : bool, default=True
        If whiten is False, the data is already considered to be
        whitened, and no whitening is performed.

    fun : {'logcosh', 'exp', 'cube'} or callable, default='logcosh'
        The functional form of the G function used in the
        approximation to neg-entropy. Could be either 'logcosh', 'exp',
        or 'cube'.
        You can also provide your own function. It should return a tuple
        containing the value of the function, and of its derivative, at the
        point. Example::

            def my_g(x):
                return x ** 3, (3 * x ** 2).mean(axis=-1)

    fun_args : dict, default=None
        Arguments to send to the functional form.
        If empty and if fun='logcosh', fun_args will take value
        {'alpha' : 1.0}.

    max_iter : int, default=200
        Maximum number of iterations during fit.

    tol : float, default=1e-4
        Tolerance on update at each iteration.

    w_init : ndarray of shape (n_components, n_components), default=None
        The initial un-mixing matrix used to initialize the algorithm.

    random_state : int, RandomState instance or None, default=None
        Used to initialize ``w_init`` when not specified, with a
        normal distribution. Pass an int for reproducible results
        across multiple function calls.
        See :term:`Glossary <random_state>`.

    Attributes
    ----------
    components_ : ndarray of shape (n_components, n_features)
        The linear operator to apply to the data to get the independent
        sources. This is equal to the unmixing matrix when ``whiten`` is
        False, and equal to ``np.dot(unmixing_matrix, self.whitening_)`` when
        ``whiten`` is True.

    mixing_ : ndarray of shape (n_features, n_components)
        The pseudo-inverse of ``components_``. It is the linear operator
        that maps independent sources to the data.

    mean_ : ndarray of shape (n_features,)
        The mean over features. Only set if `self.whiten` is True.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_iter_ : int
        If the algorithm is "deflation", n_iter is the
        maximum number of iterations run across all components. Else
        it is just the number of iterations taken to converge.

    whitening_ : ndarray of shape (n_components, n_features)
        Only set if whiten is True. This is the pre-whitening matrix
        that projects data onto the first `n_components` principal components.

    See Also
    --------
    PCA : Principal component analysis (PCA).
    IncrementalPCA : Incremental principal components analysis (IPCA).
    KernelPCA : Kernel Principal component analysis (KPCA).
    MiniBatchSparsePCA : Mini-batch Sparse Principal Components Analysis.
    SparsePCA : Sparse Principal Components Analysis (SparsePCA).

    References
    ----------
    .. [1] A. Hyvarinen and E. Oja, "Independent Component Analysis:
           Algorithms and Applications", Neural Networks, 13(4-5), 2000,
           pp. 411-430.

    Examples
    --------
    >>> from sklearn.datasets import load_digits
    >>> from sklearn.decomposition import FastICA
    >>> X, _ = load_digits(return_X_y=True)
    >>> transformer = FastICA(n_components=7,
    ...         random_state=0)
    >>> X_transformed = transformer.fit_transform(X)
    >>> X_transformed.shape
    (1797, 7)
    """

    def __init__(
        self,
        n_components=None,
        *,
        algorithm="parallel",
        whiten=True,
        fun="logcosh",
        fun_args=None,
        max_iter=200,
        tol=1e-4,
        w_init=None,
        random_state=None,
    ):
        super().__init__()
        if max_iter < 1:
            raise ValueError(
                "max_iter should be at least 1, got (max_iter={})".format(max_iter)
            )
        self.n_components = n_components
        self.algorithm = algorithm
        self.whiten = whiten
        self.fun = fun
        self.fun_args = fun_args
        self.max_iter = max_iter
        self.tol = tol
        self.w_init = w_init
        self.random_state = random_state

    def _fit(self, X, compute_sources=False):
        """Fit the model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        compute_sources : bool, default=False
            If False, sources are not computed but only the rotation matrix.
            This can save memory when working with big data. Defaults to False.

        Returns
        -------
        S : ndarray of shape (n_samples, n_components) or None
            Sources matrix. `None` if `compute_sources` is `False`.
        """
        XT = self._validate_data(
            X, copy=self.whiten, dtype=FLOAT_DTYPES, ensure_min_samples=2
        ).T
        fun_args = {} if self.fun_args is None else self.fun_args
        random_state = check_random_state(self.random_state)

        alpha = fun_args.get("alpha", 1.0)
        if not 1 <= alpha <= 2:
            raise ValueError("alpha must be in [1,2]")

        if self.fun == "logcosh":
            g = _logcosh
        elif self.fun == "exp":
            g = _exp
        elif self.fun == "cube":
            g = _cube
        elif callable(self.fun):

            def g(x, fun_args):
                return self.fun(x, **fun_args)

        else:
            exc = ValueError if isinstance(self.fun, str) else TypeError
            raise exc(
                "Unknown function %r;"
                " should be one of 'logcosh', 'exp', 'cube' or callable"
                % self.fun
            )

        n_features, n_samples = XT.shape

        n_components = self.n_components
        if not self.whiten and n_components is not None:
            n_components = None
            warnings.warn("Ignoring n_components with whiten=False.")

        if n_components is None:
            n_components = min(n_samples, n_features)
        if n_components > min(n_samples, n_features):
            n_components = min(n_samples, n_features)
            warnings.warn(
                "n_components is too large: it will be set to %s" % n_components
            )

        if self.whiten:
            # Centering the features of X
            X_mean = XT.mean(axis=-1)
            XT -= X_mean[:, np.newaxis]

            # Whitening and preprocessing by PCA
            u, d, _ = linalg.svd(XT, full_matrices=False, check_finite=False)

            del _
            K = (u / d).T[:n_components]  # see (6.33) p.140
            del u, d
            X1 = np.dot(K, XT)
            # see (13.6) p.267: here X1 is white and the data in X has been
            # projected onto a subspace by PCA
            X1 *= np.sqrt(n_samples)
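            # After scaling, X1 has unit empirical covariance:
            # np.dot(X1, X1.T) / n_samples is (approximately) the identity.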
        else:
            # X must be cast to floats to avoid typing issues with numpy
            # 2.0 and the line below
            X1 = as_float_array(XT, copy=False)  # copy has been taken care of

        w_init = self.w_init
        if w_init is None:
            w_init = np.asarray(
                random_state.normal(size=(n_components, n_components)), dtype=X1.dtype
            )
        else:
            w_init = np.asarray(w_init)
            if w_init.shape != (n_components, n_components):
                raise ValueError(
                    "w_init has invalid shape -- should be %(shape)s"
                    % {"shape": (n_components, n_components)}
                )

        kwargs = {
            "tol": self.tol,
            "g": g,
            "fun_args": fun_args,
            "max_iter": self.max_iter,
            "w_init": w_init,
        }

        if self.algorithm == "parallel":
            W, n_iter = _ica_par(X1, **kwargs)
        elif self.algorithm == "deflation":
            W, n_iter = _ica_def(X1, **kwargs)
        else:
            raise ValueError(
                "Invalid algorithm: must be either `parallel` or `deflation`."
            )
        del X1

        if compute_sources:
            if self.whiten:
                S = np.linalg.multi_dot([W, K, XT]).T
            else:
                S = np.dot(W, XT).T
        else:
            S = None

        self.n_iter_ = n_iter

        if self.whiten:
            self.components_ = np.dot(W, K)
            self.mean_ = X_mean
            self.whitening_ = K
        else:
            self.components_ = W

        self.mixing_ = linalg.pinv(self.components_, check_finite=False)
        self._unmixing = W

        return S

    def fit_transform(self, X, y=None):
        """Fit the model and recover the sources from X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_components)
            Estimated sources obtained by transforming the data with the
            estimated unmixing matrix.
        """
        return self._fit(X, compute_sources=True)

    def fit(self, X, y=None):
        """Fit the model to X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        self._fit(X, compute_sources=False)
        return self

    def transform(self, X, copy=True):
        """Recover the sources from X (apply the unmixing matrix).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to transform, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        copy : bool, default=True
            If False, data passed to transform can be overwritten.
            Defaults to True.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_components)
            Estimated sources obtained by transforming the data with the
            estimated unmixing matrix.
        """
        check_is_fitted(self)

        X = self._validate_data(
            X, copy=(copy and self.whiten), dtype=FLOAT_DTYPES, reset=False
        )
        if self.whiten:
            X -= self.mean_

        return np.dot(X, self.components_.T)

    def inverse_transform(self, X, copy=True):
        """Transform the sources back to the mixed data (apply mixing matrix).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_components)
            Sources, where `n_samples` is the number of samples
            and `n_components` is the number of components.

        copy : bool, default=True
            If False, data passed to inverse_transform can be overwritten.
            Defaults to True.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_features)
            Reconstructed data obtained with the mixing matrix.
        """
        check_is_fitted(self)

        X = check_array(X, copy=(copy and self.whiten), dtype=FLOAT_DTYPES)
        X = np.dot(X, self.mixing_.T)
        if self.whiten:
            X += self.mean_

        return X