1"""Orthogonal matching pursuit algorithms
2"""
3
4# Author: Vlad Niculae
5#
6# License: BSD 3 clause
7
8import warnings
9from math import sqrt
10
11import numpy as np
12from scipy import linalg
13from scipy.linalg.lapack import get_lapack_funcs
14from joblib import Parallel
15
16from ._base import LinearModel, _pre_fit, _deprecate_normalize
17from ..base import RegressorMixin, MultiOutputMixin
18from ..utils import as_float_array, check_array
19from ..utils.fixes import delayed
20from ..model_selection import check_cv
21
22premature = (
23    "Orthogonal matching pursuit ended prematurely due to linear"
24    " dependence in the dictionary. The requested precision might"
25    " not have been met."
26)
27
28
29def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True, return_path=False):
30    """Orthogonal Matching Pursuit step using the Cholesky decomposition.
31
32    Parameters
33    ----------
34    X : ndarray of shape (n_samples, n_features)
35        Input dictionary. Columns are assumed to have unit norm.
36
37    y : ndarray of shape (n_samples,)
38        Input targets.
39
40    n_nonzero_coefs : int
41        Targeted number of non-zero elements.
42
43    tol : float, default=None
44        Targeted squared error, if not None overrides n_nonzero_coefs.
45
46    copy_X : bool, default=True
47        Whether the design matrix X must be copied by the algorithm. A false
48        value is only helpful if X is already Fortran-ordered, otherwise a
49        copy is made anyway.
50
51    return_path : bool, default=False
52        Whether to return every value of the nonzero coefficients along the
53        forward path. Useful for cross-validation.
54
55    Returns
56    -------
57    gamma : ndarray of shape (n_nonzero_coefs,)
58        Non-zero elements of the solution.
59
60    idx : ndarray of shape (n_nonzero_coefs,)
61        Indices of the positions of the elements in gamma within the solution
62        vector.
63
64    coef : ndarray of shape (n_features, n_nonzero_coefs)
65        The first k values of column k correspond to the coefficient value
66        for the active features at that step. The lower left triangle contains
67        garbage. Only returned if ``return_path=True``.
68
69    n_active : int
70        Number of active features at convergence.
71    """
72    if copy_X:
73        X = X.copy("F")
74    else:  # even if we are allowed to overwrite, still copy it if bad order
75        X = np.asfortranarray(X)
76
77    min_float = np.finfo(X.dtype).eps
78    nrm2, swap = linalg.get_blas_funcs(("nrm2", "swap"), (X,))
79    (potrs,) = get_lapack_funcs(("potrs",), (X,))
80
81    alpha = np.dot(X.T, y)
82    residual = y
83    gamma = np.empty(0)
84    n_active = 0
85    indices = np.arange(X.shape[1])  # keeping track of swapping
86
87    max_features = X.shape[1] if tol is not None else n_nonzero_coefs
88
89    L = np.empty((max_features, max_features), dtype=X.dtype)
90
91    if return_path:
92        coefs = np.empty_like(L)
93
94    while True:
95        lam = np.argmax(np.abs(np.dot(X.T, residual)))
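        # Active columns have been swapped to the front of X, and the
        # residual is orthogonal to them after the least-squares update, so
        # an argmax falling among the first n_active positions means no
        # remaining atom is usefully correlated with the residual.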
        if lam < n_active or alpha[lam] ** 2 < min_float:
            # atom already selected or inner product too small
            warnings.warn(premature, RuntimeWarning, stacklevel=2)
            break

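        # Rank-one update of the Cholesky factor L of the active Gram matrix
        # X_act' X_act: solve L w = X_act' x_lam for the new off-diagonal row
        # and set the new diagonal entry to sqrt(||x_lam||^2 - ||w||^2).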
        if n_active > 0:
            # Updates the Cholesky decomposition of X' X
            L[n_active, :n_active] = np.dot(X[:, :n_active].T, X[:, lam])
            linalg.solve_triangular(
                L[:n_active, :n_active],
                L[n_active, :n_active],
                trans=0,
                lower=1,
                overwrite_b=True,
                check_finite=False,
            )
            v = nrm2(L[n_active, :n_active]) ** 2
            Lkk = linalg.norm(X[:, lam]) ** 2 - v
            if Lkk <= min_float:  # selected atoms are dependent
                warnings.warn(premature, RuntimeWarning, stacklevel=2)
                break
            L[n_active, n_active] = sqrt(Lkk)
        else:
            L[0, 0] = linalg.norm(X[:, lam])

        X.T[n_active], X.T[lam] = swap(X.T[n_active], X.T[lam])
        alpha[n_active], alpha[lam] = alpha[lam], alpha[n_active]
        indices[n_active], indices[lam] = indices[lam], indices[n_active]
        n_active += 1

        # solves LL'x = X'y as a composition of two triangular systems
        gamma, _ = potrs(
            L[:n_active, :n_active], alpha[:n_active], lower=True, overwrite_b=False
        )

        if return_path:
            coefs[:n_active, n_active - 1] = gamma
        residual = y - np.dot(X[:, :n_active], gamma)
        if tol is not None and nrm2(residual) ** 2 <= tol:
            break
        elif n_active == max_features:
            break

    if return_path:
        return gamma, indices[:n_active], coefs[:, :n_active], n_active
    else:
        return gamma, indices[:n_active], n_active


def _gram_omp(
    Gram,
    Xy,
    n_nonzero_coefs,
    tol_0=None,
    tol=None,
    copy_Gram=True,
    copy_Xy=True,
    return_path=False,
):
    """Orthogonal Matching Pursuit step on a precomputed Gram matrix.

    This function uses the Cholesky decomposition method.

    Parameters
    ----------
    Gram : ndarray of shape (n_features, n_features)
        Gram matrix of the input data matrix.

    Xy : ndarray of shape (n_features,)
        Input targets.

    n_nonzero_coefs : int
        Targeted number of non-zero elements.

    tol_0 : float, default=None
        Squared norm of y, required if tol is not None.

    tol : float, default=None
        Targeted squared error, if not None overrides n_nonzero_coefs.

    copy_Gram : bool, default=True
        Whether the gram matrix must be copied by the algorithm. A false
        value is only helpful if it is already Fortran-ordered, otherwise a
        copy is made anyway.

    copy_Xy : bool, default=True
        Whether the covariance vector Xy must be copied by the algorithm.
        If False, it may be overwritten.

    return_path : bool, default=False
        Whether to return every value of the nonzero coefficients along the
        forward path. Useful for cross-validation.

    Returns
    -------
    gamma : ndarray of shape (n_nonzero_coefs,)
        Non-zero elements of the solution.

    idx : ndarray of shape (n_nonzero_coefs,)
        Indices of the positions of the elements in gamma within the solution
        vector.

    coefs : ndarray of shape (n_features, n_nonzero_coefs)
        The first k values of column k correspond to the coefficient value
        for the active features at that step. The lower left triangle contains
        garbage. Only returned if ``return_path=True``.

    n_active : int
        Number of active features at convergence.
    """
    Gram = Gram.copy("F") if copy_Gram else np.asfortranarray(Gram)

    if copy_Xy or not Xy.flags.writeable:
        Xy = Xy.copy()

    min_float = np.finfo(Gram.dtype).eps
    nrm2, swap = linalg.get_blas_funcs(("nrm2", "swap"), (Gram,))
    (potrs,) = get_lapack_funcs(("potrs",), (Gram,))

    indices = np.arange(len(Gram))  # keeping track of swapping
    alpha = Xy
    tol_curr = tol_0
    delta = 0
    gamma = np.empty(0)
    n_active = 0

    max_features = len(Gram) if tol is not None else n_nonzero_coefs

    L = np.empty((max_features, max_features), dtype=Gram.dtype)

    L[0, 0] = 1.0
    if return_path:
        coefs = np.empty_like(L)

    while True:
        lam = np.argmax(np.abs(alpha))
        if lam < n_active or alpha[lam] ** 2 < min_float:
            # selected same atom twice, or inner product too small
            warnings.warn(premature, RuntimeWarning, stacklevel=3)
            break
        if n_active > 0:
            L[n_active, :n_active] = Gram[lam, :n_active]
            linalg.solve_triangular(
                L[:n_active, :n_active],
                L[n_active, :n_active],
                trans=0,
                lower=1,
                overwrite_b=True,
                check_finite=False,
            )
            v = nrm2(L[n_active, :n_active]) ** 2
            Lkk = Gram[lam, lam] - v
            if Lkk <= min_float:  # selected atoms are dependent
                warnings.warn(premature, RuntimeWarning, stacklevel=3)
                break
            L[n_active, n_active] = sqrt(Lkk)
        else:
            L[0, 0] = sqrt(Gram[lam, lam])

        Gram[n_active], Gram[lam] = swap(Gram[n_active], Gram[lam])
        Gram.T[n_active], Gram.T[lam] = swap(Gram.T[n_active], Gram.T[lam])
        indices[n_active], indices[lam] = indices[lam], indices[n_active]
        Xy[n_active], Xy[lam] = Xy[lam], Xy[n_active]
        n_active += 1
        # solves LL'x = X'y as a composition of two triangular systems
        gamma, _ = potrs(
            L[:n_active, :n_active], Xy[:n_active], lower=True, overwrite_b=False
        )
        if return_path:
            coefs[:n_active, n_active - 1] = gamma
        beta = np.dot(Gram[:, :n_active], gamma)
        alpha = Xy - beta
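        # Track the squared residual norm without forming the residual:
        # ||y - X gamma||^2 = ||y||^2 - gamma' G_act gamma, and the second
        # term equals np.inner(gamma, beta[:n_active]). tol_curr starts at
        # tol_0 (= ||y||^2) and is refreshed with the latest value of delta.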
        if tol is not None:
            tol_curr += delta
            delta = np.inner(gamma, beta[:n_active])
            tol_curr -= delta
            if abs(tol_curr) <= tol:
                break
        elif n_active == max_features:
            break

    if return_path:
        return gamma, indices[:n_active], coefs[:, :n_active], n_active
    else:
        return gamma, indices[:n_active], n_active


def orthogonal_mp(
    X,
    y,
    *,
    n_nonzero_coefs=None,
    tol=None,
    precompute=False,
    copy_X=True,
    return_path=False,
    return_n_iter=False,
):
    r"""Orthogonal Matching Pursuit (OMP).

    Solves n_targets Orthogonal Matching Pursuit problems.
    An instance of the problem has the form:

    When parametrized by the number of non-zero coefficients using
    `n_nonzero_coefs`:
    argmin ||y - X\gamma||^2 subject to ||\gamma||_0 <= n_{nonzero coefs}

    When parametrized by error using the parameter `tol`:
    argmin ||\gamma||_0 subject to ||y - X\gamma||^2 <= tol

    Read more in the :ref:`User Guide <omp>`.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Input data. Columns are assumed to have unit norm.

    y : ndarray of shape (n_samples,) or (n_samples, n_targets)
        Input targets.

    n_nonzero_coefs : int, default=None
        Desired number of non-zero entries in the solution. If None (by
        default) this value is set to 10% of n_features (but at least 1).

    tol : float, default=None
        Maximum squared norm of the residual. If not None, overrides
        n_nonzero_coefs.

    precompute : 'auto' or bool, default=False
        Whether to perform precomputations. Improves performance when n_targets
        or n_samples is very large.

    copy_X : bool, default=True
        Whether the design matrix X must be copied by the algorithm. A false
        value is only helpful if X is already Fortran-ordered, otherwise a
        copy is made anyway.

    return_path : bool, default=False
        Whether to return every value of the nonzero coefficients along the
        forward path. Useful for cross-validation.

    return_n_iter : bool, default=False
        Whether or not to return the number of iterations.

    Returns
    -------
    coef : ndarray of shape (n_features,) or (n_features, n_targets)
        Coefficients of the OMP solution. If `return_path=True`, this contains
        the whole coefficient path. In this case its shape is
        (n_features, n_features) or (n_features, n_targets, n_features) and
        iterating over the last axis yields coefficients in increasing order
        of active features.

    n_iters : array-like or int
        Number of active features across every target. Returned only if
        `return_n_iter` is set to True.

    See Also
    --------
    OrthogonalMatchingPursuit
    orthogonal_mp_gram
    lars_path
    sklearn.decomposition.sparse_encode

    Notes
    -----
    Orthogonal matching pursuit was introduced in S. Mallat, Z. Zhang,
    Matching pursuits with time-frequency dictionaries, IEEE Transactions on
    Signal Processing, Vol. 41, No. 12. (December 1993), pp. 3397-3415.
    (http://blanche.polytechnique.fr/~mallat/papiers/MallatPursuit93.pdf)

    This implementation is based on Rubinstein, R., Zibulevsky, M. and Elad,
    M., Efficient Implementation of the K-SVD Algorithm using Batch Orthogonal
    Matching Pursuit Technical Report - CS Technion, April 2008.
    https://www.cs.technion.ac.il/~ronrubin/Publications/KSVD-OMP-v2.pdf

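    Examples
    --------
    A minimal usage sketch on a toy dictionary (this assumes the usual public
    re-export from :mod:`sklearn.linear_model`; the recovered support depends
    on the data):

    >>> import numpy as np
    >>> from sklearn.linear_model import orthogonal_mp
    >>> rng = np.random.RandomState(0)
    >>> X = rng.randn(10, 5)
    >>> X /= np.sqrt(np.sum(X ** 2, axis=0))  # columns must have unit norm
    >>> y = 3.0 * X[:, 2]  # signal built from a single atom
    >>> coef = orthogonal_mp(X, y, n_nonzero_coefs=1)
    >>> coef.shape
    (5,)
    >>> int(np.flatnonzero(coef)[0])  # the atom used to build y is recovered
    2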
371    """
372    X = check_array(X, order="F", copy=copy_X)
373    copy_X = False
374    if y.ndim == 1:
375        y = y.reshape(-1, 1)
376    y = check_array(y)
377    if y.shape[1] > 1:  # subsequent targets will be affected
378        copy_X = True
379    if n_nonzero_coefs is None and tol is None:
380        # default for n_nonzero_coefs is 0.1 * n_features
381        # but at least one.
382        n_nonzero_coefs = max(int(0.1 * X.shape[1]), 1)
383    if tol is not None and tol < 0:
384        raise ValueError("Epsilon cannot be negative")
385    if tol is None and n_nonzero_coefs <= 0:
386        raise ValueError("The number of atoms must be positive")
387    if tol is None and n_nonzero_coefs > X.shape[1]:
388        raise ValueError(
389            "The number of atoms cannot be more than the number of features"
390        )
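    # With precompute="auto", switch to the Gram-based solver when the
    # problem is overdetermined (n_samples > n_features), where the
    # n_features x n_features Gram matrix is the cheaper object to work with.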
    if precompute == "auto":
        precompute = X.shape[0] > X.shape[1]
    if precompute:
        G = np.dot(X.T, X)
        G = np.asfortranarray(G)
        Xy = np.dot(X.T, y)
        if tol is not None:
            norms_squared = np.sum((y ** 2), axis=0)
        else:
            norms_squared = None
        return orthogonal_mp_gram(
            G,
            Xy,
            n_nonzero_coefs=n_nonzero_coefs,
            tol=tol,
            norms_squared=norms_squared,
            copy_Gram=copy_X,
            copy_Xy=False,
            return_path=return_path,
        )

    if return_path:
        coef = np.zeros((X.shape[1], y.shape[1], X.shape[1]))
    else:
        coef = np.zeros((X.shape[1], y.shape[1]))
    n_iters = []

    for k in range(y.shape[1]):
        out = _cholesky_omp(
            X, y[:, k], n_nonzero_coefs, tol, copy_X=copy_X, return_path=return_path
        )
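        # Along the path, column i of `coefs` holds the first i + 1
        # coefficient values for the atoms listed in `idx`; scatter them into
        # `coef` so that slicing its last axis at step i yields the solution
        # with i + 1 active atoms.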
        if return_path:
            _, idx, coefs, n_iter = out
            coef = coef[:, :, : len(idx)]
            for n_active, x in enumerate(coefs.T):
                coef[idx[: n_active + 1], k, n_active] = x[: n_active + 1]
        else:
            x, idx, n_iter = out
            coef[idx, k] = x
        n_iters.append(n_iter)

    if y.shape[1] == 1:
        n_iters = n_iters[0]

    if return_n_iter:
        return np.squeeze(coef), n_iters
    else:
        return np.squeeze(coef)


def orthogonal_mp_gram(
    Gram,
    Xy,
    *,
    n_nonzero_coefs=None,
    tol=None,
    norms_squared=None,
    copy_Gram=True,
    copy_Xy=True,
    return_path=False,
    return_n_iter=False,
):
    """Gram Orthogonal Matching Pursuit (OMP).

    Solves n_targets Orthogonal Matching Pursuit problems using only
    the Gram matrix X.T * X and the product X.T * y.

    Read more in the :ref:`User Guide <omp>`.

    Parameters
    ----------
    Gram : ndarray of shape (n_features, n_features)
        Gram matrix of the input data: X.T * X.

    Xy : ndarray of shape (n_features,) or (n_features, n_targets)
        Input targets multiplied by X: X.T * y.

    n_nonzero_coefs : int, default=None
        Desired number of non-zero entries in the solution. If None (by
        default) this value is set to 10% of n_features.

    tol : float, default=None
        Maximum squared norm of the residual. If not None, overrides
        n_nonzero_coefs.

    norms_squared : array-like of shape (n_targets,), default=None
        Squared L2 norms of the rows of y. Required if tol is not None.

    copy_Gram : bool, default=True
        Whether the gram matrix must be copied by the algorithm. A false
        value is only helpful if it is already Fortran-ordered, otherwise a
        copy is made anyway.

    copy_Xy : bool, default=True
        Whether the covariance vector Xy must be copied by the algorithm.
        If False, it may be overwritten.

    return_path : bool, default=False
        Whether to return every value of the nonzero coefficients along the
        forward path. Useful for cross-validation.

    return_n_iter : bool, default=False
        Whether or not to return the number of iterations.

    Returns
    -------
    coef : ndarray of shape (n_features,) or (n_features, n_targets)
        Coefficients of the OMP solution. If `return_path=True`, this contains
        the whole coefficient path. In this case its shape is
        (n_features, n_features) or (n_features, n_targets, n_features) and
        iterating over the last axis yields coefficients in increasing order
        of active features.

    n_iters : array-like or int
        Number of active features across every target. Returned only if
        `return_n_iter` is set to True.

    See Also
    --------
    OrthogonalMatchingPursuit
    orthogonal_mp
    lars_path
    sklearn.decomposition.sparse_encode

    Notes
    -----
    Orthogonal matching pursuit was introduced in S. Mallat, Z. Zhang,
    Matching pursuits with time-frequency dictionaries, IEEE Transactions on
    Signal Processing, Vol. 41, No. 12. (December 1993), pp. 3397-3415.
    (http://blanche.polytechnique.fr/~mallat/papiers/MallatPursuit93.pdf)

    This implementation is based on Rubinstein, R., Zibulevsky, M. and Elad,
    M., Efficient Implementation of the K-SVD Algorithm using Batch Orthogonal
    Matching Pursuit Technical Report - CS Technion, April 2008.
    https://www.cs.technion.ac.il/~ronrubin/Publications/KSVD-OMP-v2.pdf

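    Examples
    --------
    A minimal usage sketch mirroring :func:`orthogonal_mp`, but starting from
    the precomputed quantities X.T * X and X.T * y (assumes the usual public
    re-export from :mod:`sklearn.linear_model`):

    >>> import numpy as np
    >>> from sklearn.linear_model import orthogonal_mp_gram
    >>> rng = np.random.RandomState(0)
    >>> X = rng.randn(10, 5)
    >>> X /= np.sqrt(np.sum(X ** 2, axis=0))  # unit-norm columns
    >>> y = 3.0 * X[:, 2]
    >>> coef = orthogonal_mp_gram(np.dot(X.T, X), np.dot(X.T, y),
    ...                           n_nonzero_coefs=1)
    >>> int(np.flatnonzero(coef)[0])
    2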
526    """
527    Gram = check_array(Gram, order="F", copy=copy_Gram)
528    Xy = np.asarray(Xy)
529    if Xy.ndim > 1 and Xy.shape[1] > 1:
530        # or subsequent target will be affected
531        copy_Gram = True
532    if Xy.ndim == 1:
533        Xy = Xy[:, np.newaxis]
534        if tol is not None:
535            norms_squared = [norms_squared]
536    if copy_Xy or not Xy.flags.writeable:
537        # Make the copy once instead of many times in _gram_omp itself.
538        Xy = Xy.copy()
539
540    if n_nonzero_coefs is None and tol is None:
541        n_nonzero_coefs = int(0.1 * len(Gram))
542    if tol is not None and norms_squared is None:
543        raise ValueError(
544            "Gram OMP needs the precomputed norms in order "
545            "to evaluate the error sum of squares."
546        )
547    if tol is not None and tol < 0:
548        raise ValueError("Epsilon cannot be negative")
549    if tol is None and n_nonzero_coefs <= 0:
550        raise ValueError("The number of atoms must be positive")
551    if tol is None and n_nonzero_coefs > len(Gram):
552        raise ValueError(
553            "The number of atoms cannot be more than the number of features"
554        )
555
556    if return_path:
557        coef = np.zeros((len(Gram), Xy.shape[1], len(Gram)))
558    else:
559        coef = np.zeros((len(Gram), Xy.shape[1]))
560
561    n_iters = []
562    for k in range(Xy.shape[1]):
563        out = _gram_omp(
564            Gram,
565            Xy[:, k],
566            n_nonzero_coefs,
567            norms_squared[k] if tol is not None else None,
568            tol,
569            copy_Gram=copy_Gram,
570            copy_Xy=False,
571            return_path=return_path,
572        )
573        if return_path:
574            _, idx, coefs, n_iter = out
575            coef = coef[:, :, : len(idx)]
576            for n_active, x in enumerate(coefs.T):
577                coef[idx[: n_active + 1], k, n_active] = x[: n_active + 1]
578        else:
579            x, idx, n_iter = out
580            coef[idx, k] = x
581        n_iters.append(n_iter)
582
583    if Xy.shape[1] == 1:
584        n_iters = n_iters[0]
585
586    if return_n_iter:
587        return np.squeeze(coef), n_iters
588    else:
589        return np.squeeze(coef)
590
591
592class OrthogonalMatchingPursuit(MultiOutputMixin, RegressorMixin, LinearModel):
593    """Orthogonal Matching Pursuit model (OMP).
594
595    Read more in the :ref:`User Guide <omp>`.
596
597    Parameters
598    ----------
    n_nonzero_coefs : int, default=None
        Desired number of non-zero entries in the solution. If None (by
        default) this value is set to 10% of n_features (but at least 1).

    tol : float, default=None
        Maximum squared norm of the residual. If not None, overrides
        n_nonzero_coefs.

    fit_intercept : bool, default=True
        Whether to calculate the intercept for this model. If set
        to false, no intercept will be used in calculations
        (i.e. data is expected to be centered).

    normalize : bool, default=True
        This parameter is ignored when ``fit_intercept`` is set to False.
        If True, the regressors X will be normalized before regression by
        subtracting the mean and dividing by the l2-norm.
        If you wish to standardize, please use
        :class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
        on an estimator with ``normalize=False``.

        .. deprecated:: 1.0
            ``normalize`` was deprecated in version 1.0. It will default
            to False in 1.2 and be removed in 1.4.

    precompute : 'auto' or bool, default='auto'
        Whether to use a precomputed Gram and Xy matrix to speed up
        calculations. Improves performance when :term:`n_targets` or
        :term:`n_samples` is very large.

    Attributes
    ----------
    coef_ : ndarray of shape (n_features,) or (n_targets, n_features)
        Parameter vector (w in the formula).

    intercept_ : float or ndarray of shape (n_targets,)
        Independent term in decision function.

    n_iter_ : int or array-like
        Number of active features across every target.

    n_nonzero_coefs_ : int
        The number of non-zero coefficients in the solution. If
        `n_nonzero_coefs` is None and `tol` is None this value is either set
        to 10% of `n_features` or 1, whichever is greater.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    See Also
    --------
    orthogonal_mp : Solves n_targets Orthogonal Matching Pursuit problems.
    orthogonal_mp_gram : Solves n_targets Orthogonal Matching Pursuit
        problems using only the Gram matrix X.T * X and the product X.T * y.
    lars_path : Compute Least Angle Regression or Lasso path using LARS algorithm.
    Lars : Least Angle Regression model a.k.a. LAR.
    LassoLars : Lasso model fit with Least Angle Regression a.k.a. Lars.
    sklearn.decomposition.sparse_encode : Generic sparse coding.
        Each column of the result is the solution to a Lasso problem.
    OrthogonalMatchingPursuitCV : Cross-validated
        Orthogonal Matching Pursuit model (OMP).

    Notes
    -----
    Orthogonal matching pursuit was introduced in S. Mallat, Z. Zhang,
    Matching pursuits with time-frequency dictionaries, IEEE Transactions on
    Signal Processing, Vol. 41, No. 12. (December 1993), pp. 3397-3415.
    (http://blanche.polytechnique.fr/~mallat/papiers/MallatPursuit93.pdf)

    This implementation is based on Rubinstein, R., Zibulevsky, M. and Elad,
    M., Efficient Implementation of the K-SVD Algorithm using Batch Orthogonal
    Matching Pursuit Technical Report - CS Technion, April 2008.
    https://www.cs.technion.ac.il/~ronrubin/Publications/KSVD-OMP-v2.pdf

    Examples
    --------
    >>> from sklearn.linear_model import OrthogonalMatchingPursuit
    >>> from sklearn.datasets import make_regression
    >>> X, y = make_regression(noise=4, random_state=0)
    >>> reg = OrthogonalMatchingPursuit(normalize=False).fit(X, y)
    >>> reg.score(X, y)
    0.9991...
    >>> reg.predict(X[:1,])
    array([-78.3854...])
    """

    def __init__(
        self,
        *,
        n_nonzero_coefs=None,
        tol=None,
        fit_intercept=True,
        normalize="deprecated",
        precompute="auto",
    ):
        self.n_nonzero_coefs = n_nonzero_coefs
        self.tol = tol
        self.fit_intercept = fit_intercept
        self.normalize = normalize
        self.precompute = precompute

    def fit(self, X, y):
        """Fit the model using X, y as training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,) or (n_samples, n_targets)
            Target values. Will be cast to X's dtype if necessary.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        _normalize = _deprecate_normalize(
            self.normalize, default=True, estimator_name=self.__class__.__name__
        )

        X, y = self._validate_data(X, y, multi_output=True, y_numeric=True)
        n_features = X.shape[1]

        X, y, X_offset, y_offset, X_scale, Gram, Xy = _pre_fit(
            X, y, None, self.precompute, _normalize, self.fit_intercept, copy=True
        )

        if y.ndim == 1:
            y = y[:, np.newaxis]

        if self.n_nonzero_coefs is None and self.tol is None:
            # default for n_nonzero_coefs is 0.1 * n_features
            # but at least one.
            self.n_nonzero_coefs_ = max(int(0.1 * n_features), 1)
        else:
            self.n_nonzero_coefs_ = self.n_nonzero_coefs

        if Gram is False:
            coef_, self.n_iter_ = orthogonal_mp(
                X,
                y,
                n_nonzero_coefs=self.n_nonzero_coefs_,
                tol=self.tol,
                precompute=False,
                copy_X=True,
                return_n_iter=True,
            )
        else:
            norms_sq = np.sum(y ** 2, axis=0) if self.tol is not None else None

            coef_, self.n_iter_ = orthogonal_mp_gram(
                Gram,
                Xy=Xy,
                n_nonzero_coefs=self.n_nonzero_coefs_,
                tol=self.tol,
                norms_squared=norms_sq,
                copy_Gram=True,
                copy_Xy=True,
                return_n_iter=True,
            )
        self.coef_ = coef_.T
        self._set_intercept(X_offset, y_offset, X_scale)
        return self


def _omp_path_residues(
    X_train,
    y_train,
    X_test,
    y_test,
    copy=True,
    fit_intercept=True,
    normalize=True,
    max_iter=100,
):
783    """Compute the residues on left-out data for a full LARS path.
784
785    Parameters
786    ----------
787    X_train : ndarray of shape (n_samples, n_features)
788        The data to fit the LARS on.
789
790    y_train : ndarray of shape (n_samples)
791        The target variable to fit LARS on.

    X_test : ndarray of shape (n_samples, n_features)
        The data to compute the residues on.

    y_test : ndarray of shape (n_samples)
        The target variable to compute the residues on.

    copy : bool, default=True
        Whether X_train, X_test, y_train and y_test should be copied.  If
        False, they may be overwritten.

    fit_intercept : bool, default=True
        Whether to calculate the intercept for this model. If set
        to false, no intercept will be used in calculations
        (i.e. data is expected to be centered).

    normalize : bool, default=True
        This parameter is ignored when ``fit_intercept`` is set to False.
        If True, the regressors X will be normalized before regression by
        subtracting the mean and dividing by the l2-norm.
        If you wish to standardize, please use
        :class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
        on an estimator with ``normalize=False``.

        .. deprecated:: 1.0
            ``normalize`` was deprecated in version 1.0. It will default
            to False in 1.2 and be removed in 1.4.

    max_iter : int, default=100
        Maximum number of iterations to perform, therefore maximum features
        to include. 100 by default.

    Returns
    -------
    residues : ndarray of shape (n_samples, max_features)
        Residues of the prediction on the test data.
    """

    if copy:
        X_train = X_train.copy()
        y_train = y_train.copy()
        X_test = X_test.copy()
        y_test = y_test.copy()

    if fit_intercept:
        X_mean = X_train.mean(axis=0)
        X_train -= X_mean
        X_test -= X_mean
        y_mean = y_train.mean(axis=0)
        y_train = as_float_array(y_train, copy=False)
        y_train -= y_mean
        y_test = as_float_array(y_test, copy=False)
        y_test -= y_mean

    if normalize:
        norms = np.sqrt(np.sum(X_train ** 2, axis=0))
        nonzeros = np.flatnonzero(norms)
        X_train[:, nonzeros] /= norms[nonzeros]

    coefs = orthogonal_mp(
        X_train,
        y_train,
        n_nonzero_coefs=max_iter,
        tol=None,
        precompute=False,
        copy_X=False,
        return_path=True,
    )
    if coefs.ndim == 1:
        coefs = coefs[:, np.newaxis]
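    # The coefficients were estimated on column-normalized training data, so
    # divide by the training column norms to express them on the original
    # scale before predicting on X_test.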
    if normalize:
        coefs[nonzeros] /= norms[nonzeros][:, np.newaxis]

    return np.dot(coefs.T, X_test.T) - y_test


class OrthogonalMatchingPursuitCV(RegressorMixin, LinearModel):
    """Cross-validated Orthogonal Matching Pursuit model (OMP).

    See glossary entry for :term:`cross-validation estimator`.

    Read more in the :ref:`User Guide <omp>`.

    Parameters
    ----------
    copy : bool, default=True
        Whether the design matrix X must be copied by the algorithm. A false
        value is only helpful if X is already Fortran-ordered, otherwise a
        copy is made anyway.

    fit_intercept : bool, default=True
        Whether to calculate the intercept for this model. If set
        to false, no intercept will be used in calculations
        (i.e. data is expected to be centered).

    normalize : bool, default=True
        This parameter is ignored when ``fit_intercept`` is set to False.
        If True, the regressors X will be normalized before regression by
        subtracting the mean and dividing by the l2-norm.
        If you wish to standardize, please use
        :class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
        on an estimator with ``normalize=False``.

        .. deprecated:: 1.0
            ``normalize`` was deprecated in version 1.0. It will default
            to False in 1.2 and be removed in 1.4.

    max_iter : int, default=None
        Maximum number of iterations to perform, therefore maximum features
        to include. If None, 10% of ``n_features`` is used, but at least 5
        and at most ``n_features``.

    cv : int, cross-validation generator or iterable, default=None
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 5-fold cross-validation,
        - integer, to specify the number of folds.
        - :term:`CV splitter`,
        - An iterable yielding (train, test) splits as arrays of indices.

        For integer/None inputs, :class:`KFold` is used.

        Refer to the :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

        .. versionchanged:: 0.22
            ``cv`` default value if None changed from 3-fold to 5-fold.

    n_jobs : int, default=None
        Number of CPUs to use during the cross validation.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    verbose : bool or int, default=False
        Sets the verbosity amount.

    Attributes
    ----------
    intercept_ : float or ndarray of shape (n_targets,)
        Independent term in decision function.

    coef_ : ndarray of shape (n_features,) or (n_targets, n_features)
        Parameter vector (w in the problem formulation).

    n_nonzero_coefs_ : int
        Estimated number of non-zero coefficients giving the best mean squared
        error over the cross-validation folds.

    n_iter_ : int or array-like
        Number of active features across every target for the model refit with
        the best hyperparameters got by cross-validating across all folds.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    See Also
    --------
    orthogonal_mp : Solves n_targets Orthogonal Matching Pursuit problems.
    orthogonal_mp_gram : Solves n_targets Orthogonal Matching Pursuit
        problems using only the Gram matrix X.T * X and the product X.T * y.
    lars_path : Compute Least Angle Regression or Lasso path using LARS algorithm.
    Lars : Least Angle Regression model a.k.a. LAR.
    LassoLars : Lasso model fit with Least Angle Regression a.k.a. Lars.
    OrthogonalMatchingPursuit : Orthogonal Matching Pursuit model (OMP).
    LarsCV : Cross-validated Least Angle Regression model.
    LassoLarsCV : Cross-validated Lasso model fit with Least Angle Regression.
    sklearn.decomposition.sparse_encode : Generic sparse coding.
        Each column of the result is the solution to a Lasso problem.

    Examples
    --------
    >>> from sklearn.linear_model import OrthogonalMatchingPursuitCV
    >>> from sklearn.datasets import make_regression
    >>> X, y = make_regression(n_features=100, n_informative=10,
    ...                        noise=4, random_state=0)
    >>> reg = OrthogonalMatchingPursuitCV(cv=5, normalize=False).fit(X, y)
    >>> reg.score(X, y)
    0.9991...
    >>> reg.n_nonzero_coefs_
    10
    >>> reg.predict(X[:1,])
    array([-78.3854...])
    """

    def __init__(
        self,
        *,
        copy=True,
        fit_intercept=True,
        normalize="deprecated",
        max_iter=None,
        cv=None,
        n_jobs=None,
        verbose=False,
    ):
        self.copy = copy
        self.fit_intercept = fit_intercept
        self.normalize = normalize
        self.max_iter = max_iter
        self.cv = cv
        self.n_jobs = n_jobs
        self.verbose = verbose

    def fit(self, X, y):
        """Fit the model using X, y as training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,)
            Target values. Will be cast to X's dtype if necessary.

        Returns
        -------
        self : object
            Returns an instance of self.
        """

        _normalize = _deprecate_normalize(
            self.normalize, default=True, estimator_name=self.__class__.__name__
        )

        X, y = self._validate_data(
            X, y, y_numeric=True, ensure_min_features=2, estimator=self
        )
        X = as_float_array(X, copy=False, force_all_finite=False)
        cv = check_cv(self.cv, classifier=False)
        max_iter = (
            min(max(int(0.1 * X.shape[1]), 5), X.shape[1])
            if not self.max_iter
            else self.max_iter
        )
        cv_paths = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
            delayed(_omp_path_residues)(
                X[train],
                y[train],
                X[test],
                y[test],
                self.copy,
                self.fit_intercept,
                _normalize,
                max_iter,
            )
            for train, test in cv.split(X)
        )

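        # Each fold's path has shape (n_steps, n_test_samples): row i holds
        # the test residuals with i + 1 atoms. Folds can stop early, so
        # truncate every path to the shortest one, average the squared
        # residuals per step across folds, and keep the step with the lowest
        # mean MSE (+ 1 turns the 0-based index into a coefficient count).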
        min_early_stop = min(fold.shape[0] for fold in cv_paths)
        mse_folds = np.array(
            [(fold[:min_early_stop] ** 2).mean(axis=1) for fold in cv_paths]
        )
        best_n_nonzero_coefs = np.argmin(mse_folds.mean(axis=0)) + 1
        self.n_nonzero_coefs_ = best_n_nonzero_coefs
        omp = OrthogonalMatchingPursuit(
            n_nonzero_coefs=best_n_nonzero_coefs,
            fit_intercept=self.fit_intercept,
            normalize=_normalize,
        )
        omp.fit(X, y)
        self.coef_ = omp.coef_
        self.intercept_ = omp.intercept_
        self.n_iter_ = omp.n_iter_
        return self
