1"""
2The :mod:`sklearn.model_selection._split` module includes classes and
3functions to split the data based on a preset strategy.
4"""
5
6# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
7#         Gael Varoquaux <gael.varoquaux@normalesup.org>
8#         Olivier Grisel <olivier.grisel@ensta.org>
9#         Raghav RV <rvraghav93@gmail.com>
10#         Leandro Hermida <hermidal@cs.umd.edu>
11#         Rodion Martynov <marrodion@gmail.com>
12# License: BSD 3 clause
13
14from collections.abc import Iterable
15from collections import defaultdict
16import warnings
17from itertools import chain, combinations
18from math import ceil, floor
19import numbers
20from abc import ABCMeta, abstractmethod
21from inspect import signature
22
23import numpy as np
24from scipy.special import comb
25
26from ..utils import indexable, check_random_state, _safe_indexing
27from ..utils import _approximate_mode
28from ..utils.validation import _num_samples, column_or_1d
29from ..utils.validation import check_array
30from ..utils.multiclass import type_of_target
31from ..base import _pprint
32
33__all__ = [
34    "BaseCrossValidator",
35    "KFold",
36    "GroupKFold",
37    "LeaveOneGroupOut",
38    "LeaveOneOut",
39    "LeavePGroupsOut",
40    "LeavePOut",
41    "RepeatedStratifiedKFold",
42    "RepeatedKFold",
43    "ShuffleSplit",
44    "GroupShuffleSplit",
45    "StratifiedKFold",
46    "StratifiedGroupKFold",
47    "StratifiedShuffleSplit",
48    "PredefinedSplit",
49    "train_test_split",
50    "check_cv",
51]
52
53
class BaseCrossValidator(metaclass=ABCMeta):
    """Abstract base class shared by all cross-validators.

    Subclasses must override `_iter_test_masks` or `_iter_test_indices`
    (at least one of them) and must implement `get_n_splits`.
    """

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        X, y, groups = indexable(X, y, groups)
        sample_indices = np.arange(_num_samples(X))
        for test_mask in self._iter_test_masks(X, y, groups):
            # The training set is the complement of the test mask.
            train_index = sample_indices[np.logical_not(test_mask)]
            yield train_index, sample_indices[test_mask]

    # Neither `_iter_test_masks` nor `_iter_test_indices` is declared
    # abstract: a subclass only has to provide one of the two.
    def _iter_test_masks(self, X=None, y=None, groups=None):
        """Generate boolean masks corresponding to test sets.

        The default implementation turns every item produced by
        `_iter_test_indices(X, y, groups)` into a boolean mask whose length
        is the number of samples in `X`.
        """
        for test_index in self._iter_test_indices(X, y, groups):
            mask = np.zeros(_num_samples(X), dtype=bool)
            mask[test_index] = True
            yield mask

    def _iter_test_indices(self, X=None, y=None, groups=None):
        """Generate integer indices corresponding to test sets.

        Not implemented here; subclasses relying on the default
        `_iter_test_masks` must override it.
        """
        raise NotImplementedError

    @abstractmethod
    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the number of splitting iterations in the cross-validator."""

    def __repr__(self):
        return _build_repr(self)
113
114
class LeaveOneOut(BaseCrossValidator):
    """Leave-One-Out cross-validator.

    Provides train/test indices to split data in train/test sets. Every
    sample is held out exactly once as a single-element test set, while the
    remaining samples form the corresponding training set.

    Note: ``LeaveOneOut()`` is equivalent to ``KFold(n_splits=n)`` and
    ``LeavePOut(p=1)`` where ``n`` is the number of samples.

    Because there are as many test sets as there are samples, this
    cross-validation method can be very costly. For large datasets one
    should favor :class:`KFold`, :class:`ShuffleSplit` or
    :class:`StratifiedKFold`.

    Read more in the :ref:`User Guide <leave_one_out>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import LeaveOneOut
    >>> X = np.array([[1, 2], [3, 4]])
    >>> y = np.array([1, 2])
    >>> loo = LeaveOneOut()
    >>> loo.get_n_splits(X)
    2
    >>> print(loo)
    LeaveOneOut()
    >>> for train_index, test_index in loo.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...     print(X_train, X_test, y_train, y_test)
    TRAIN: [1] TEST: [0]
    [[3 4]] [[1 2]] [2] [1]
    TRAIN: [0] TEST: [1]
    [[1 2]] [[3 4]] [1] [2]

    See Also
    --------
    LeaveOneGroupOut : For splitting the data according to explicit,
        domain-specific stratification of the dataset.
    GroupKFold : K-fold iterator variant with non-overlapping groups.
    """

    def _iter_test_indices(self, X, y=None, groups=None):
        n_samples = _num_samples(X)
        # A single sample (or none) cannot be split into non-empty train and
        # test sets.
        if n_samples < 2:
            raise ValueError(
                "Cannot perform LeaveOneOut with n_samples={}.".format(n_samples)
            )
        # Each sample index is a test set on its own.
        return range(n_samples)

    def get_n_splits(self, X, y=None, groups=None):
        """Return the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            The number of splitting iterations in the cross-validator, which
            equals the number of samples in `X`.

        Raises
        ------
        ValueError
            If `X` is None.
        """
        if X is None:
            raise ValueError("The 'X' parameter should not be None.")
        return _num_samples(X)
191
192
class LeavePOut(BaseCrossValidator):
    """Leave-P-Out cross-validator.

    Provides train/test indices to split data in train/test sets. Every
    distinct subset of p samples is used once as a test set, while the
    remaining n - p samples form the training set of that iteration.

    Note: ``LeavePOut(p)`` is NOT equivalent to
    ``KFold(n_splits=n_samples // p)`` which creates non-overlapping test sets.

    The number of iterations grows combinatorially with the number of
    samples, so this cross-validation method can be very costly. For
    large datasets one should favor :class:`KFold`, :class:`StratifiedKFold`
    or :class:`ShuffleSplit`.

    Read more in the :ref:`User Guide <leave_p_out>`.

    Parameters
    ----------
    p : int
        Size of the test sets. Must be strictly less than the number of
        samples.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import LeavePOut
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    >>> y = np.array([1, 2, 3, 4])
    >>> lpo = LeavePOut(2)
    >>> lpo.get_n_splits(X)
    6
    >>> print(lpo)
    LeavePOut(p=2)
    >>> for train_index, test_index in lpo.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [2 3] TEST: [0 1]
    TRAIN: [1 3] TEST: [0 2]
    TRAIN: [1 2] TEST: [0 3]
    TRAIN: [0 3] TEST: [1 2]
    TRAIN: [0 2] TEST: [1 3]
    TRAIN: [0 1] TEST: [2 3]
    """

    def __init__(self, p):
        self.p = p

    def _iter_test_indices(self, X, y=None, groups=None):
        n_samples = _num_samples(X)
        if n_samples <= self.p:
            message = "p={} must be strictly less than the number of samples={}"
            raise ValueError(message.format(self.p, n_samples))
        # One test set per p-sized combination of sample indices.
        yield from (
            np.array(test_indices)
            for test_indices in combinations(range(n_samples), self.p)
        )

    def get_n_splits(self, X, y=None, groups=None):
        """Return the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            The number of ways to choose `p` samples out of `n_samples`.

        Raises
        ------
        ValueError
            If `X` is None.
        """
        if X is None:
            raise ValueError("The 'X' parameter should not be None.")
        n_samples = _num_samples(X)
        return int(comb(n_samples, self.p, exact=True))
271
272
class _BaseKFold(BaseCrossValidator, metaclass=ABCMeta):
    """Base class for KFold, GroupKFold, and StratifiedKFold.

    It validates and stores the `n_splits`, `shuffle` and `random_state`
    parameters, and checks in `split` that there are at least as many
    samples as folds.
    """

    @abstractmethod
    def __init__(self, n_splits, *, shuffle, random_state):
        if not isinstance(n_splits, numbers.Integral):
            raise ValueError(
                "The number of folds must be of Integral type. "
                "%s of type %s was passed." % (n_splits, type(n_splits))
            )
        # Store a plain Python int whatever Integral subtype was given.
        n_splits = int(n_splits)

        if n_splits < 2:
            raise ValueError(
                "k-fold cross-validation requires at least one"
                " train/test split by setting n_splits=2 or more,"
                " got n_splits={0}.".format(n_splits)
            )

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be True or False; got {0}".format(shuffle))

        # None is the default random_state; any other value is rejected when
        # no shuffling is requested.
        if random_state is not None and not shuffle:
            raise ValueError(
                "Setting a random_state has no effect since shuffle is "
                "False. You should leave "
                "random_state to its default (None), or set shuffle=True.",
            )

        self.n_splits = n_splits
        self.shuffle = shuffle
        self.random_state = random_state

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        X, y, groups = indexable(X, y, groups)
        n_samples = _num_samples(X)
        if self.n_splits > n_samples:
            message = (
                "Cannot have number of splits n_splits={0} greater"
                " than the number of samples: n_samples={1}."
            ).format(self.n_splits, n_samples)
            raise ValueError(message)

        yield from super().split(X, y, groups)

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            The number of splitting iterations in the cross-validator, i.e.
            the `n_splits` value given at construction.
        """
        return self.n_splits
363
364
class KFold(_BaseKFold):
    """K-Folds cross-validator.

    Provides train/test indices to split data in train/test sets. The
    dataset is split into k consecutive folds (without shuffling by default).

    Each fold is then used once as a validation set while the k - 1 remaining
    folds form the training set.

    Read more in the :ref:`User Guide <k_fold>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

        .. versionchanged:: 0.22
            ``n_splits`` default value changed from 3 to 5.

    shuffle : bool, default=False
        Whether to shuffle the data before splitting into batches.
        Note that the samples within each split will not be shuffled.

    random_state : int, RandomState instance or None, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold. Otherwise, leave
        `random_state` as `None`.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import KFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([1, 2, 3, 4])
    >>> kf = KFold(n_splits=2)
    >>> kf.get_n_splits(X)
    2
    >>> print(kf)
    KFold(n_splits=2, random_state=None, shuffle=False)
    >>> for train_index, test_index in kf.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [2 3] TEST: [0 1]
    TRAIN: [0 1] TEST: [2 3]

    Notes
    -----
    The first ``n_samples % n_splits`` folds have size
    ``n_samples // n_splits + 1``, other folds have size
    ``n_samples // n_splits``, where ``n_samples`` is the number of samples.

    Randomized CV splitters may return different results for each call of
    split. You can make the results identical by setting `random_state`
    to an integer.

    See Also
    --------
    StratifiedKFold : Takes class information into account to avoid building
        folds with imbalanced class distributions (for binary or multiclass
        classification tasks).

    GroupKFold : K-fold iterator variant with non-overlapping groups.

    RepeatedKFold : Repeats K-Fold n times.
    """

    def __init__(self, n_splits=5, *, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _iter_test_indices(self, X, y=None, groups=None):
        n_samples = _num_samples(X)
        indices = np.arange(n_samples)
        if self.shuffle:
            check_random_state(self.random_state).shuffle(indices)

        # Every fold receives `base_size` samples and the first `n_larger`
        # folds receive one extra sample.
        base_size, n_larger = divmod(n_samples, self.n_splits)
        fold_sizes = np.full(self.n_splits, base_size, dtype=int)
        fold_sizes[:n_larger] += 1

        # Consecutive, non-overlapping [start, stop) boundaries of the folds.
        stops = np.cumsum(fold_sizes)
        starts = stops - fold_sizes
        for start, stop in zip(starts, stops):
            yield indices[start:stop]
451
452
class GroupKFold(_BaseKFold):
    """K-fold iterator variant with non-overlapping groups.

    The same group will not appear in two different folds (the number of
    distinct groups has to be at least equal to the number of folds).

    The folds are approximately balanced in the sense that the number of
    samples is approximately the same in each test fold.

    Read more in the :ref:`User Guide <group_k_fold>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

        .. versionchanged:: 0.22
            ``n_splits`` default value changed from 3 to 5.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import GroupKFold
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    >>> y = np.array([1, 2, 3, 4])
    >>> groups = np.array([0, 0, 2, 2])
    >>> group_kfold = GroupKFold(n_splits=2)
    >>> group_kfold.get_n_splits(X, y, groups)
    2
    >>> print(group_kfold)
    GroupKFold(n_splits=2)
    >>> for train_index, test_index in group_kfold.split(X, y, groups):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...     print(X_train, X_test, y_train, y_test)
    ...
    TRAIN: [0 1] TEST: [2 3]
    [[1 2]
     [3 4]] [[5 6]
     [7 8]] [1 2] [3 4]
    TRAIN: [2 3] TEST: [0 1]
    [[5 6]
     [7 8]] [[1 2]
     [3 4]] [3 4] [1 2]

    See Also
    --------
    LeaveOneGroupOut : For splitting the data according to explicit
        domain-specific stratification of the dataset.
    """

    def __init__(self, n_splits=5):
        # Shuffling is not exposed by this splitter.
        super().__init__(n_splits, shuffle=False, random_state=None)

    def _iter_test_indices(self, X, y, groups):
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(groups, ensure_2d=False, dtype=None)

        # Encode every sample's group label as an integer code in
        # [0, n_groups).
        unique_groups, group_codes = np.unique(groups, return_inverse=True)
        n_groups = len(unique_groups)

        if self.n_splits > n_groups:
            raise ValueError(
                "Cannot have number of splits n_splits=%d greater"
                " than the number of groups: %d." % (self.n_splits, n_groups)
            )

        # Number of samples in each group, and the group codes ordered from
        # the most frequent group to the least frequent one.
        group_sizes = np.bincount(group_codes)
        largest_first = np.argsort(group_sizes)[::-1]

        # Number of samples assigned to each fold so far.
        fold_loads = np.zeros(self.n_splits)

        # Fold assigned to each group code.
        fold_of_group = np.zeros(n_groups)

        # Greedy balancing: the next largest group always goes to the fold
        # that currently holds the fewest samples.
        for group_code in largest_first:
            lightest_fold = np.argmin(fold_loads)
            fold_loads[lightest_fold] += group_sizes[group_code]
            fold_of_group[group_code] = lightest_fold

        fold_of_sample = fold_of_group[group_codes]

        for fold in range(self.n_splits):
            yield np.where(fold_of_sample == fold)[0]

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set. Although the signature defaults it to None, a
            ValueError is raised while iterating if it is None.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        return super().split(X, y, groups)
571
572
class StratifiedKFold(_BaseKFold):
    """Stratified K-Folds cross-validator.

    Provides train/test indices to split data in train/test sets.

    This cross-validation object is a variation of KFold that returns
    stratified folds. The folds are made by preserving the percentage of
    samples for each class.

    Read more in the :ref:`User Guide <stratified_k_fold>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

        .. versionchanged:: 0.22
            ``n_splits`` default value changed from 3 to 5.

    shuffle : bool, default=False
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.

    random_state : int, RandomState instance or None, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold for each class.
        Otherwise, leave `random_state` as `None`.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedKFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> skf = StratifiedKFold(n_splits=2)
    >>> skf.get_n_splits(X, y)
    2
    >>> print(skf)
    StratifiedKFold(n_splits=2, random_state=None, shuffle=False)
    >>> for train_index, test_index in skf.split(X, y):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [1 3] TEST: [0 2]
    TRAIN: [0 2] TEST: [1 3]

    Notes
    -----
    The implementation is designed to:

    * Generate test sets such that all contain the same distribution of
      classes, or as close as possible.
    * Be invariant to class label: relabelling ``y = ["Happy", "Sad"]`` to
      ``y = [1, 0]`` should not change the indices generated.
    * Preserve order dependencies in the dataset ordering, when
      ``shuffle=False``: all samples from class k in some test set were
      contiguous in y, or separated in y by samples from classes other than k.
    * Generate test sets where the smallest and largest differ by at most one
      sample.

    .. versionchanged:: 0.22
        The previous implementation did not follow the last constraint.

    See Also
    --------
    RepeatedStratifiedKFold : Repeats Stratified K-Fold n times.
    """

    def __init__(self, n_splits=5, *, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _make_test_folds(self, X, y=None):
        """Return, for every sample, the index of the test fold it belongs to."""
        rng = check_random_state(self.random_state)
        y = np.asarray(y)
        target_type = type_of_target(y)
        supported_types = ("binary", "multiclass")
        if target_type not in supported_types:
            raise ValueError(
                "Supported target types are: {}. Got {!r} instead.".format(
                    supported_types, target_type
                )
            )

        y = column_or_1d(y)

        # `lexical_codes` encodes y following the sorted order of the labels.
        # Ranking the first-occurrence positions re-encodes the classes by
        # order of appearance instead: 0 is the first label met in y, 1 the
        # second one, and so on.
        _, first_seen, lexical_codes = np.unique(
            y, return_index=True, return_inverse=True
        )
        _, appearance_rank = np.unique(first_seen, return_inverse=True)
        y_encoded = appearance_rank[lexical_codes]

        n_classes = len(first_seen)
        class_counts = np.bincount(y_encoded)
        smallest_class = np.min(class_counts)
        if np.all(self.n_splits > class_counts):
            raise ValueError(
                "n_splits=%d cannot be greater than the"
                " number of members in each class." % (self.n_splits)
            )
        if self.n_splits > smallest_class:
            warnings.warn(
                "The least populated class in y has only %d"
                " members, which is less than n_splits=%d."
                % (smallest_class, self.n_splits),
                UserWarning,
            )

        # Work out how many samples of each class go to each fold by dealing
        # the sorted class codes round robin over the folds. Row i of
        # `allocation` holds the per-class counts of fold i.
        sorted_codes = np.sort(y_encoded)
        allocation = np.asarray(
            [
                np.bincount(sorted_codes[offset :: self.n_splits], minlength=n_classes)
                for offset in range(self.n_splits)
            ]
        )

        # To keep the data order dependencies as much as the stratification
        # constraint allows, the samples of each class are assigned to folds
        # in contiguous blocks (which shuffle=True then scrambles).
        fold_ids = np.arange(self.n_splits)
        test_folds = np.empty(len(y), dtype="i")
        for k, per_fold_counts in enumerate(allocation.T):
            # `per_fold_counts` is column k of `allocation`: the number of
            # class-k samples wanted in each test set. Repeating the fold
            # ids accordingly gives one fold id per class-k sample.
            folds_for_class = fold_ids.repeat(per_fold_counts)
            if self.shuffle:
                rng.shuffle(folds_for_class)
            test_folds[y_encoded == k] = folds_for_class
        return test_folds

    def _iter_test_masks(self, X, y=None, groups=None):
        test_folds = self._make_test_folds(X, y)
        for fold in range(self.n_splits):
            yield test_folds == fold

    def split(self, X, y, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

            Note that providing ``y`` is sufficient to generate the splits and
            hence ``np.zeros(n_samples)`` may be used as a placeholder for
            ``X`` instead of actual training data.

        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.
            Stratification is done based on the y labels.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        Randomized CV splitters may return different results for each call of
        split. You can make the results identical by setting `random_state`
        to an integer.
        """
        # `y` is validated right away, before the generator is iterated.
        y = check_array(y, ensure_2d=False, dtype=None)
        return super().split(X, y, groups)
749
750
751class StratifiedGroupKFold(_BaseKFold):
752    """Stratified K-Folds iterator variant with non-overlapping groups.
753
    This cross-validation object is a variation of StratifiedKFold that attempts
    to return stratified folds with non-overlapping groups. The folds are made by
    preserving the percentage of samples for each class.
757
758    The same group will not appear in two different folds (the number of
759    distinct groups has to be at least equal to the number of folds).
760
761    The difference between GroupKFold and StratifiedGroupKFold is that
762    the former attempts to create balanced folds such that the number of
763    distinct groups is approximately the same in each fold, whereas
764    StratifiedGroupKFold attempts to create folds which preserve the
765    percentage of samples for each class as much as possible given the
766    constraint of non-overlapping groups between splits.
767
768    Read more in the :ref:`User Guide <cross_validation>`.
769
770    Parameters
771    ----------
772    n_splits : int, default=5
773        Number of folds. Must be at least 2.
774
775    shuffle : bool, default=False
776        Whether to shuffle each class's samples before splitting into batches.
777        Note that the samples within each split will not be shuffled.
778        This implementation can only shuffle groups that have approximately the
779        same y distribution, no global shuffle will be performed.
780
781    random_state : int or RandomState instance, default=None
782        When `shuffle` is True, `random_state` affects the ordering of the
783        indices, which controls the randomness of each fold for each class.
784        Otherwise, leave `random_state` as `None`.
785        Pass an int for reproducible output across multiple function calls.
786        See :term:`Glossary <random_state>`.
787
788    Examples
789    --------
790    >>> import numpy as np
791    >>> from sklearn.model_selection import StratifiedGroupKFold
792    >>> X = np.ones((17, 2))
793    >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
794    >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
795    >>> cv = StratifiedGroupKFold(n_splits=3)
796    >>> for train_idxs, test_idxs in cv.split(X, y, groups):
797    ...     print("TRAIN:", groups[train_idxs])
798    ...     print("      ", y[train_idxs])
799    ...     print(" TEST:", groups[test_idxs])
800    ...     print("      ", y[test_idxs])
801    TRAIN: [1 1 2 2 4 5 5 5 5 8 8]
802           [0 0 1 1 1 0 0 0 0 0 0]
803     TEST: [3 3 3 6 6 7]
804           [1 1 1 0 0 0]
805    TRAIN: [3 3 3 4 5 5 5 5 6 6 7]
806           [1 1 1 1 0 0 0 0 0 0 0]
807     TEST: [1 1 2 2 8 8]
808           [0 0 1 1 0 0]
809    TRAIN: [1 1 2 2 3 3 3 6 6 7 8 8]
810           [0 0 1 1 1 1 1 0 0 0 0 0]
811     TEST: [4 5 5 5 5]
812           [1 0 0 0 0]
813
814    Notes
815    -----
816    The implementation is designed to:
817
818    * Mimic the behavior of StratifiedKFold as much as possible for trivial
819      groups (e.g. when each group contains only one sample).
820    * Be invariant to class label: relabelling ``y = ["Happy", "Sad"]`` to
821      ``y = [1, 0]`` should not change the indices generated.
822    * Stratify based on samples as much as possible while keeping
823      non-overlapping groups constraint. That means that in some cases when
824      there is a small number of groups containing a large number of samples
825      the stratification will not be possible and the behavior will be close
826      to GroupKFold.
827
828    See also
829    --------
830    StratifiedKFold: Takes class information into account to build folds which
831        retain class distributions (for binary or multiclass classification
832        tasks).
833
834    GroupKFold: K-fold iterator variant with non-overlapping groups.
835    """
836
837    def __init__(self, n_splits=5, shuffle=False, random_state=None):
838        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)
839
840    def _iter_test_indices(self, X, y, groups):
841        # Implementation is based on this kaggle kernel:
842        # https://www.kaggle.com/jakubwasikowski/stratified-group-k-fold-cross-validation
843        # and is a subject to Apache 2.0 License. You may obtain a copy of the
844        # License at http://www.apache.org/licenses/LICENSE-2.0
845        # Changelist:
846        # - Refactored function to a class following scikit-learn KFold
847        #   interface.
848        # - Added heuristic for assigning group to the least populated fold in
849        #   cases when all other criteria are equal
850        # - Swtch from using python ``Counter`` to ``np.unique`` to get class
851        #   distribution
852        # - Added scikit-learn checks for input: checking that target is binary
853        #   or multiclass, checking passed random state, checking that number
854        #   of splits is less than number of members in each class, checking
855        #   that least populated class has more members than there are splits.
856        rng = check_random_state(self.random_state)
857        y = np.asarray(y)
858        type_of_target_y = type_of_target(y)
859        allowed_target_types = ("binary", "multiclass")
860        if type_of_target_y not in allowed_target_types:
861            raise ValueError(
862                "Supported target types are: {}. Got {!r} instead.".format(
863                    allowed_target_types, type_of_target_y
864                )
865            )
866
867        y = column_or_1d(y)
868        _, y_inv, y_cnt = np.unique(y, return_inverse=True, return_counts=True)
869        if np.all(self.n_splits > y_cnt):
870            raise ValueError(
871                "n_splits=%d cannot be greater than the"
872                " number of members in each class." % (self.n_splits)
873            )
874        n_smallest_class = np.min(y_cnt)
875        if self.n_splits > n_smallest_class:
876            warnings.warn(
877                "The least populated class in y has only %d"
878                " members, which is less than n_splits=%d."
879                % (n_smallest_class, self.n_splits),
880                UserWarning,
881            )
882        n_classes = len(y_cnt)
883
884        _, groups_inv, groups_cnt = np.unique(
885            groups, return_inverse=True, return_counts=True
886        )
887        y_counts_per_group = np.zeros((len(groups_cnt), n_classes))
888        for class_idx, group_idx in zip(y_inv, groups_inv):
889            y_counts_per_group[group_idx, class_idx] += 1
890
891        y_counts_per_fold = np.zeros((self.n_splits, n_classes))
892        groups_per_fold = defaultdict(set)
893
894        if self.shuffle:
895            rng.shuffle(y_counts_per_group)
896
897        # Stable sort to keep shuffled order for groups with the same
898        # class distribution variance
899        sorted_groups_idx = np.argsort(
900            -np.std(y_counts_per_group, axis=1), kind="mergesort"
901        )
902
903        for group_idx in sorted_groups_idx:
904            group_y_counts = y_counts_per_group[group_idx]
905            best_fold = self._find_best_fold(
906                y_counts_per_fold=y_counts_per_fold,
907                y_cnt=y_cnt,
908                group_y_counts=group_y_counts,
909            )
910            y_counts_per_fold[best_fold] += group_y_counts
911            groups_per_fold[best_fold].add(group_idx)
912
913        for i in range(self.n_splits):
914            test_indices = [
915                idx
916                for idx, group_idx in enumerate(groups_inv)
917                if group_idx in groups_per_fold[i]
918            ]
919            yield test_indices
920
921    def _find_best_fold(self, y_counts_per_fold, y_cnt, group_y_counts):
922        best_fold = None
923        min_eval = np.inf
924        min_samples_in_fold = np.inf
925        for i in range(self.n_splits):
926            y_counts_per_fold[i] += group_y_counts
927            # Summarise the distribution over classes in each proposed fold
928            std_per_class = np.std(y_counts_per_fold / y_cnt.reshape(1, -1), axis=0)
929            y_counts_per_fold[i] -= group_y_counts
930            fold_eval = np.mean(std_per_class)
931            samples_in_fold = np.sum(y_counts_per_fold[i])
932            is_current_fold_better = (
933                fold_eval < min_eval
934                or np.isclose(fold_eval, min_eval)
935                and samples_in_fold < min_samples_in_fold
936            )
937            if is_current_fold_better:
938                min_eval = fold_eval
939                min_samples_in_fold = samples_in_fold
940                best_fold = i
941        return best_fold
942
943
class TimeSeriesSplit(_BaseKFold):
    """Time Series cross-validator

    Provides train/test indices to split time series data samples
    that are observed at fixed time intervals, in train/test sets.
    Test indices grow from one split to the next, which is why shuffling
    is inappropriate for this cross-validator.

    This cross-validation object is a variation of :class:`KFold`.
    In the kth split, it returns first k folds as train set and the
    (k+1)th fold as test set.

    Note that unlike standard cross-validation methods, successive
    training sets are supersets of those that come before them.

    Read more in the :ref:`User Guide <time_series_split>`.

    .. versionadded:: 0.18

    Parameters
    ----------
    n_splits : int, default=5
        Number of splits. Must be at least 2.

        .. versionchanged:: 0.22
            ``n_splits`` default value changed from 3 to 5.

    max_train_size : int, default=None
        Maximum size for a single training set.

    test_size : int, default=None
        Used to limit the size of the test set. Defaults to
        ``n_samples // (n_splits + 1)``, which is the maximum allowed value
        with ``gap=0``.

        .. versionadded:: 0.24

    gap : int, default=0
        Number of samples to exclude from the end of each train set before
        the test set.

        .. versionadded:: 0.24

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import TimeSeriesSplit
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([1, 2, 3, 4, 5, 6])
    >>> tscv = TimeSeriesSplit()
    >>> print(tscv)
    TimeSeriesSplit(gap=0, max_train_size=None, n_splits=5, test_size=None)
    >>> for train_index, test_index in tscv.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [0] TEST: [1]
    TRAIN: [0 1] TEST: [2]
    TRAIN: [0 1 2] TEST: [3]
    TRAIN: [0 1 2 3] TEST: [4]
    TRAIN: [0 1 2 3 4] TEST: [5]
    >>> # Fix test_size to 2 with 12 samples
    >>> X = np.random.randn(12, 2)
    >>> y = np.random.randint(0, 2, 12)
    >>> tscv = TimeSeriesSplit(n_splits=3, test_size=2)
    >>> for train_index, test_index in tscv.split(X):
    ...    print("TRAIN:", train_index, "TEST:", test_index)
    ...    X_train, X_test = X[train_index], X[test_index]
    ...    y_train, y_test = y[train_index], y[test_index]
    TRAIN: [0 1 2 3 4 5] TEST: [6 7]
    TRAIN: [0 1 2 3 4 5 6 7] TEST: [8 9]
    TRAIN: [0 1 2 3 4 5 6 7 8 9] TEST: [10 11]
    >>> # Add in a 2 period gap
    >>> tscv = TimeSeriesSplit(n_splits=3, test_size=2, gap=2)
    >>> for train_index, test_index in tscv.split(X):
    ...    print("TRAIN:", train_index, "TEST:", test_index)
    ...    X_train, X_test = X[train_index], X[test_index]
    ...    y_train, y_test = y[train_index], y[test_index]
    TRAIN: [0 1 2 3] TEST: [6 7]
    TRAIN: [0 1 2 3 4 5] TEST: [8 9]
    TRAIN: [0 1 2 3 4 5 6 7] TEST: [10 11]

    Notes
    -----
    The training set has size ``i * n_samples // (n_splits + 1)
    + n_samples % (n_splits + 1)`` in the ``i`` th split,
    with a test set of size ``n_samples//(n_splits + 1)`` by default,
    where ``n_samples`` is the number of samples.
    """

    def __init__(self, n_splits=5, *, max_train_size=None, test_size=None, gap=0):
        # Shuffling is always disabled for this splitter.
        super().__init__(n_splits, shuffle=False, random_state=None)
        self.gap = gap
        self.test_size = test_size
        self.max_train_size = max_train_size

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        The test windows are consecutive blocks of ``test_size`` samples
        taken from the end of the data; every training set stops ``gap``
        samples before its test window and is optionally truncated to the
        last ``max_train_size`` samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        groups : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Raises
        ------
        ValueError
            If there are more folds than samples, or if no sample would be
            left for training once the gap and all test windows are
            accounted for.
        """
        X, y, groups = indexable(X, y, groups)
        n_samples = _num_samples(X)
        n_splits = self.n_splits
        n_folds = n_splits + 1
        gap = self.gap
        if self.test_size is not None:
            test_size = self.test_size
        else:
            test_size = n_samples // n_folds

        # Make sure we have enough samples for the given split parameters
        if n_folds > n_samples:
            raise ValueError(
                f"Cannot have number of folds={n_folds} greater"
                f" than the number of samples={n_samples}."
            )
        if n_samples - gap - (test_size * n_splits) <= 0:
            raise ValueError(
                f"Too many splits={n_splits} for number of samples"
                f"={n_samples} with test_size={test_size} and gap={gap}."
            )

        indices = np.arange(n_samples)
        first_test_start = n_samples - n_splits * test_size

        for test_start in range(first_test_start, n_samples, test_size):
            train_end = test_start - gap
            if self.max_train_size and self.max_train_size < train_end:
                # Keep only the most recent ``max_train_size`` samples.
                train = indices[train_end - self.max_train_size : train_end]
            else:
                train = indices[:train_end]
            test = indices[test_start : test_start + test_size]
            yield train, test
1099
1100
class LeaveOneGroupOut(BaseCrossValidator):
    """Leave One Group Out cross-validator

    Provides train/test indices to split data according to a third-party
    provided group. This group information can be used to encode arbitrary
    domain specific stratifications of the samples as integers.

    For instance the groups could be the year of collection of the samples
    and thus allow for cross-validation against time-based splits.

    Each split uses the samples of exactly one distinct group value as the
    test set.

    Read more in the :ref:`User Guide <leave_one_group_out>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import LeaveOneGroupOut
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    >>> y = np.array([1, 2, 1, 2])
    >>> groups = np.array([1, 1, 2, 2])
    >>> logo = LeaveOneGroupOut()
    >>> logo.get_n_splits(X, y, groups)
    2
    >>> logo.get_n_splits(groups=groups)  # 'groups' is always required
    2
    >>> print(logo)
    LeaveOneGroupOut()
    >>> for train_index, test_index in logo.split(X, y, groups):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...     print(X_train, X_test, y_train, y_test)
    TRAIN: [2 3] TEST: [0 1]
    [[5 6]
     [7 8]] [[1 2]
     [3 4]] [1 2] [1 2]
    TRAIN: [0 1] TEST: [2 3]
    [[1 2]
     [3 4]] [[5 6]
     [7 8]] [1 2] [1 2]

    """

    def _iter_test_masks(self, X, y, groups):
        # Yields one boolean mask per distinct group value; ``X`` and ``y``
        # are not used here.
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        # We make a copy of groups to avoid side-effects during iteration
        groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
        distinct_groups = np.unique(groups)
        if len(distinct_groups) < 2:
            raise ValueError(
                "The groups parameter contains fewer than 2 unique groups "
                "(%s). LeaveOneGroupOut expects at least 2." % distinct_groups
            )
        yield from (groups == label for label in distinct_groups)

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        The result is the number of distinct values found in ``groups``.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set. This 'groups' parameter must always be specified to
            calculate the number of splits, though the other parameters can be
            omitted.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.

        Raises
        ------
        ValueError
            If ``groups`` is None.
        """
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        validated_groups = check_array(groups, ensure_2d=False, dtype=None)
        distinct_groups = np.unique(validated_groups)
        return len(distinct_groups)

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Delegates to the parent class implementation.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        splits = super().split(X, y, groups)
        return splits
1209
1210
class LeavePGroupsOut(BaseCrossValidator):
    """Leave P Group(s) Out cross-validator

    Provides train/test indices to split data according to a third-party
    provided group. This group information can be used to encode arbitrary
    domain specific stratifications of the samples as integers.

    For instance the groups could be the year of collection of the samples
    and thus allow for cross-validation against time-based splits.

    The difference between LeavePGroupsOut and LeaveOneGroupOut is that
    the former builds the test sets with all the samples assigned to
    ``p`` different values of the groups while the latter uses samples
    all assigned the same groups.

    Read more in the :ref:`User Guide <leave_p_groups_out>`.

    Parameters
    ----------
    n_groups : int
        Number of groups (``p``) to leave out in the test split.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import LeavePGroupsOut
    >>> X = np.array([[1, 2], [3, 4], [5, 6]])
    >>> y = np.array([1, 2, 1])
    >>> groups = np.array([1, 2, 3])
    >>> lpgo = LeavePGroupsOut(n_groups=2)
    >>> lpgo.get_n_splits(X, y, groups)
    3
    >>> lpgo.get_n_splits(groups=groups)  # 'groups' is always required
    3
    >>> print(lpgo)
    LeavePGroupsOut(n_groups=2)
    >>> for train_index, test_index in lpgo.split(X, y, groups):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...     print(X_train, X_test, y_train, y_test)
    TRAIN: [2] TEST: [0 1]
    [[5 6]] [[1 2]
     [3 4]] [1] [1 2]
    TRAIN: [1] TEST: [0 2]
    [[3 4]] [[1 2]
     [5 6]] [2] [1 1]
    TRAIN: [0] TEST: [1 2]
    [[1 2]] [[3 4]
     [5 6]] [1] [2 1]

    See Also
    --------
    GroupKFold : K-fold iterator variant with non-overlapping groups.
    """

    def __init__(self, n_groups):
        self.n_groups = n_groups

    def _iter_test_masks(self, X, y, groups):
        # Yields one boolean mask per combination of ``n_groups`` distinct
        # group values; ``y`` is not used here.
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
        distinct_groups = np.unique(groups)
        if not self.n_groups < len(distinct_groups):
            raise ValueError(
                "The groups parameter contains fewer than (or equal to) "
                "n_groups (%d) numbers of unique groups (%s). LeavePGroupsOut "
                "expects that at least n_groups + 1 (%d) unique groups be "
                "present" % (self.n_groups, distinct_groups, self.n_groups + 1)
            )
        position_combos = combinations(range(len(distinct_groups)), self.n_groups)
        n_samples = _num_samples(X)
        for positions in position_combos:
            test_mask = np.zeros(n_samples, dtype=bool)
            held_out_labels = distinct_groups[np.array(positions)]
            for label in held_out_labels:
                test_mask[groups == label] = True
            yield test_mask

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        The result is the number of ways to choose ``n_groups`` values among
        the distinct values found in ``groups``.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set. This 'groups' parameter must always be specified to
            calculate the number of splits, though the other parameters can be
            omitted.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.

        Raises
        ------
        ValueError
            If ``groups`` is None.
        """
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(groups, ensure_2d=False, dtype=None)
        n_distinct = len(np.unique(groups))
        n_combinations = comb(n_distinct, self.n_groups, exact=True)
        return int(n_combinations)

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Delegates to the parent class implementation.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        splits = super().split(X, y, groups)
        return splits
1341
1342
class _RepeatedSplits(metaclass=ABCMeta):
    """Repeated splits for an arbitrary randomized CV splitter.

    Runs a cross-validator ``n_repeats`` times; a new cross-validator is
    built for every repetition with ``shuffle=True`` and a shared random
    number generator, so each repetition is randomized differently.

    Parameters
    ----------
    cv : callable
        Cross-validator class.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : int, RandomState instance or None, default=None
        Passes `random_state` to the arbitrary repeating cross validator.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    **cvargs : additional params
        Constructor parameters for cv. Must not contain random_state
        and shuffle.
    """

    def __init__(self, cv, *, n_repeats=10, random_state=None, **cvargs):
        if not isinstance(n_repeats, numbers.Integral):
            raise ValueError("Number of repetitions must be of Integral type.")

        if n_repeats <= 0:
            raise ValueError("Number of repetitions must be greater than 0.")

        # These two are always supplied by this class when building ``cv``.
        if "random_state" in cvargs or "shuffle" in cvargs:
            raise ValueError("cvargs must not contain random_state or shuffle.")

        self.cv = cv
        self.n_repeats = n_repeats
        self.random_state = random_state
        self.cvargs = cvargs

    def split(self, X, y=None, groups=None):
        """Generates indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        rng = check_random_state(self.random_state)
        repetitions_left = self.n_repeats

        while repetitions_left > 0:
            # A new splitter per repetition, all drawing from the same rng.
            splitter = self.cv(random_state=rng, shuffle=True, **self.cvargs)
            for train_index, test_index in splitter.split(X, y, groups):
                yield train_index, test_index
            repetitions_left -= 1

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        This is the number of splits reported by one instance of the wrapped
        cross-validator, multiplied by ``n_repeats``.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.
            ``np.zeros(n_samples)`` may be used as a placeholder.

        y : object
            Always ignored, exists for compatibility.
            ``np.zeros(n_samples)`` may be used as a placeholder.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        rng = check_random_state(self.random_state)
        splitter = self.cv(random_state=rng, shuffle=True, **self.cvargs)
        splits_per_repeat = splitter.get_n_splits(X, y, groups)
        return splits_per_repeat * self.n_repeats

    def __repr__(self):
        return _build_repr(self)
1442
1443
class RepeatedKFold(_RepeatedSplits):
    """Repeated K-Fold cross validator.

    Runs K-Fold ``n_repeats`` times, using a different randomization for
    every repetition.

    Read more in the :ref:`User Guide <repeated_k_fold>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of each repeated cross-validation instance.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import RepeatedKFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=2652124)
    >>> for train_index, test_index in rkf.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...
    TRAIN: [0 1] TEST: [2 3]
    TRAIN: [2 3] TEST: [0 1]
    TRAIN: [1 2] TEST: [0 3]
    TRAIN: [0 3] TEST: [1 2]

    Notes
    -----
    Randomized CV splitters may return different results for each call of
    split. You can make the results identical by setting `random_state`
    to an integer.

    See Also
    --------
    RepeatedStratifiedKFold : Repeats Stratified K-Fold n times.
    """

    def __init__(self, *, n_splits=5, n_repeats=10, random_state=None):
        # ``n_splits`` travels as an extra keyword meant for the KFold
        # constructor; the other two are handled by the parent class.
        super().__init__(
            KFold,
            n_repeats=n_repeats,
            random_state=random_state,
            n_splits=n_splits,
        )
1496
1497
class RepeatedStratifiedKFold(_RepeatedSplits):
    """Repeated Stratified K-Fold cross validator.

    Runs Stratified K-Fold ``n_repeats`` times, using a different
    randomization for every repetition.

    Read more in the :ref:`User Guide <repeated_k_fold>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : int, RandomState instance or None, default=None
        Controls the generation of the random states for each repetition.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import RepeatedStratifiedKFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2,
    ...     random_state=36851234)
    >>> for train_index, test_index in rskf.split(X, y):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...
    TRAIN: [1 2] TEST: [0 3]
    TRAIN: [0 3] TEST: [1 2]
    TRAIN: [1 3] TEST: [0 2]
    TRAIN: [0 2] TEST: [1 3]

    Notes
    -----
    Randomized CV splitters may return different results for each call of
    split. You can make the results identical by setting `random_state`
    to an integer.

    See Also
    --------
    RepeatedKFold : Repeats K-Fold n times.
    """

    def __init__(self, *, n_splits=5, n_repeats=10, random_state=None):
        # ``n_splits`` travels as an extra keyword meant for the
        # StratifiedKFold constructor; the rest is handled by the parent.
        super().__init__(
            StratifiedKFold,
            n_splits=n_splits,
            n_repeats=n_repeats,
            random_state=random_state,
        )
1555
1556
class BaseShuffleSplit(metaclass=ABCMeta):
    """Base class for ShuffleSplit and StratifiedShuffleSplit

    Subclasses must implement ``_iter_indices``, which produces the
    (train, test) index pairs returned by ``split``.
    """

    def __init__(
        self, n_splits=10, *, test_size=None, train_size=None, random_state=None
    ):
        self.n_splits = n_splits
        self.train_size = train_size
        self.test_size = test_size
        self.random_state = random_state
        # Stored on the instance under a private name; not a constructor
        # parameter.
        self._default_test_size = 0.1

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        Randomized CV splitters may return different results for each call of
        split. You can make the results identical by setting `random_state`
        to an integer.
        """
        X, y, groups = indexable(X, y, groups)
        index_pairs = self._iter_indices(X, y, groups)
        for pair in index_pairs:
            train, test = pair
            yield train, test

    @abstractmethod
    def _iter_indices(self, X, y=None, groups=None):
        """Generate (train, test) indices"""

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        n_splits = self.n_splits
        return n_splits

    def __repr__(self):
        return _build_repr(self)
1630
1631
class ShuffleSplit(BaseShuffleSplit):
    """Random permutation cross-validator.

    Yields pairs of index arrays that divide the data into a training set
    and a test set, drawing a fresh random permutation for every split.

    Note: contrary to other cross-validation strategies, random splits
    give no guarantee that all folds will be different, although this is
    still very likely for sizeable datasets.

    Read more in the :ref:`User Guide <ShuffleSplit>`.

    Parameters
    ----------
    n_splits : int, default=10
        Number of re-shuffling & splitting iterations.

    test_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split. If int, represents the
        absolute number of test samples. If None, the value is set to the
        complement of the train size. If ``train_size`` is also None, it will
        be set to 0.1.

    train_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the dataset to include in the train split. If
        int, represents the absolute number of train samples. If None,
        the value is automatically set to the complement of the test size.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of the training and testing indices produced.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import ShuffleSplit
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [3, 4], [5, 6]])
    >>> y = np.array([1, 2, 1, 2, 1, 2])
    >>> rs = ShuffleSplit(n_splits=5, test_size=.25, random_state=0)
    >>> rs.get_n_splits(X)
    5
    >>> print(rs)
    ShuffleSplit(n_splits=5, random_state=0, test_size=0.25, train_size=None)
    >>> for train_index, test_index in rs.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    TRAIN: [1 3 0 4] TEST: [5 2]
    TRAIN: [4 0 2 5] TEST: [1 3]
    TRAIN: [1 2 4 0] TEST: [3 5]
    TRAIN: [3 4 1 0] TEST: [5 2]
    TRAIN: [3 5 1 0] TEST: [2 4]
    >>> rs = ShuffleSplit(n_splits=5, train_size=0.5, test_size=.25,
    ...                   random_state=0)
    >>> for train_index, test_index in rs.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    TRAIN: [1 3 0] TEST: [5 2]
    TRAIN: [4 0 2] TEST: [1 3]
    TRAIN: [1 2 4] TEST: [3 5]
    TRAIN: [3 4 1] TEST: [5 2]
    TRAIN: [3 5 1] TEST: [2 4]
    """

    def __init__(
        self, n_splits=10, *, test_size=None, train_size=None, random_state=None
    ):
        super().__init__(
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state,
        )
        # Fallback test fraction applied when neither size is specified.
        self._default_test_size = 0.1

    def _iter_indices(self, X, y=None, groups=None):
        """Yield ``(train, test)`` index arrays cut from random permutations."""
        n_samples = _num_samples(X)
        n_train, n_test = _validate_shuffle_split(
            n_samples,
            self.test_size,
            self.train_size,
            default_test_size=self._default_test_size,
        )
        # The test set takes the head of each permutation and the train set
        # the slice right after it; anything beyond `train_stop` is unused.
        train_stop = n_test + n_train

        rng = check_random_state(self.random_state)
        for _ in range(self.n_splits):
            shuffled = rng.permutation(n_samples)
            yield shuffled[n_test:train_stop], shuffled[:n_test]
1722
1723
class GroupShuffleSplit(ShuffleSplit):
    """Shuffle-Group(s)-Out cross-validation iterator

    Provides randomized train/test indices to split data according to a
    third-party provided group. This group information can be used to encode
    arbitrary domain specific stratifications of the samples as integers.

    For instance the groups could be the year of collection of the samples
    and thus allow for cross-validation against time-based splits.

    The difference between LeavePGroupsOut and GroupShuffleSplit is that
    the former generates splits using all subsets of size ``p`` unique groups,
    whereas GroupShuffleSplit generates a user-determined number of random
    test splits, each with a user-determined fraction of unique groups.

    For example, a less computationally intensive alternative to
    ``LeavePGroupsOut(p=10)`` would be
    ``GroupShuffleSplit(test_size=10, n_splits=100)``.

    Note: The parameters ``test_size`` and ``train_size`` refer to groups, and
    not to samples, as in ShuffleSplit.

    Read more in the :ref:`User Guide <group_shuffle_split>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of re-shuffling & splitting iterations.

    test_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the proportion
        of groups to include in the test split (rounded up). If int,
        represents the absolute number of test groups. If None, the value is
        set to the complement of the train size. If ``train_size`` is also
        None, it will be set to 0.2.

    train_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the groups to include in the train split. If
        int, represents the absolute number of train groups. If None,
        the value is automatically set to the complement of the test size.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of the training and testing indices produced.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import GroupShuffleSplit
    >>> X = np.ones(shape=(8, 2))
    >>> y = np.ones(shape=(8, 1))
    >>> groups = np.array([1, 1, 2, 2, 2, 3, 3, 3])
    >>> print(groups.shape)
    (8,)
    >>> gss = GroupShuffleSplit(n_splits=2, train_size=.7, random_state=42)
    >>> gss.get_n_splits()
    2
    >>> for train_idx, test_idx in gss.split(X, y, groups):
    ...     print("TRAIN:", train_idx, "TEST:", test_idx)
    TRAIN: [2 3 4 5 6 7] TEST: [0 1]
    TRAIN: [0 1 5 6 7] TEST: [2 3 4]
    """

    def __init__(
        self, n_splits=5, *, test_size=None, train_size=None, random_state=None
    ):
        super().__init__(
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state,
        )
        # Fallback test fraction (of groups) used when neither size is given.
        self._default_test_size = 0.2

    def _iter_indices(self, X, y, groups):
        """Yield ``(train, test)`` sample indices built from shuffled groups.

        Raises
        ------
        ValueError
            If ``groups`` is None.
        """
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(groups, ensure_2d=False, dtype=None)
        # `group_indices[i]` is the position in `classes` of sample i's group.
        classes, group_indices = np.unique(groups, return_inverse=True)
        # Let ShuffleSplit partition the unique groups rather than the samples.
        for group_train, group_test in super()._iter_indices(X=classes):
            # these are the indices of classes in the partition
            # invert them into data indices
            # (np.isin replaces np.in1d, which is deprecated in NumPy >= 2.0)
            train = np.flatnonzero(np.isin(group_indices, group_train))
            test = np.flatnonzero(np.isin(group_indices, group_test))

            yield train, test

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        Randomized CV splitters may return different results for each call of
        split. You can make the results identical by setting `random_state`
        to an integer.
        """
        return super().split(X, y, groups)
1847
1848
class StratifiedShuffleSplit(BaseShuffleSplit):
    """Stratified ShuffleSplit cross-validator.

    Provides train/test indices to split data in train/test sets.

    This cross-validation object combines StratifiedKFold and ShuffleSplit
    and returns stratified randomized folds: each fold is built so that the
    percentage of samples of every class is preserved.

    Note: like the ShuffleSplit strategy, stratified random splits
    give no guarantee that all folds will be different, although this is
    still very likely for sizeable datasets.

    Read more in the :ref:`User Guide <stratified_shuffle_split>`.

    Parameters
    ----------
    n_splits : int, default=10
        Number of re-shuffling & splitting iterations.

    test_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split. If int, represents the
        absolute number of test samples. If None, the value is set to the
        complement of the train size. If ``train_size`` is also None, it will
        be set to 0.1.

    train_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the dataset to include in the train split. If
        int, represents the absolute number of train samples. If None,
        the value is automatically set to the complement of the test size.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of the training and testing indices produced.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedShuffleSplit
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 0, 1, 1, 1])
    >>> sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
    >>> sss.get_n_splits(X, y)
    5
    >>> print(sss)
    StratifiedShuffleSplit(n_splits=5, random_state=0, ...)
    >>> for train_index, test_index in sss.split(X, y):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [5 2 3] TEST: [4 1 0]
    TRAIN: [5 1 4] TEST: [0 2 3]
    TRAIN: [5 0 2] TEST: [4 3 1]
    TRAIN: [4 1 0] TEST: [2 3 5]
    TRAIN: [0 5 1] TEST: [3 4 2]
    """

    def __init__(
        self, n_splits=10, *, test_size=None, train_size=None, random_state=None
    ):
        super().__init__(
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state,
        )
        # Fallback test fraction applied when neither size is specified.
        self._default_test_size = 0.1

    def _iter_indices(self, X, y, groups=None):
        """Yield stratified ``(train, test)`` index arrays.

        Raises
        ------
        ValueError
            If a class has fewer than 2 members, or if the resolved train or
            test size is smaller than the number of classes.
        """
        n_samples = _num_samples(X)
        y = check_array(y, ensure_2d=False, dtype=None)
        n_train, n_test = _validate_shuffle_split(
            n_samples,
            self.test_size,
            self.train_size,
            default_test_size=self._default_test_size,
        )

        if y.ndim == 2:
            # Multi-label y: collapse each row into one string label.
            # join is used because str(row) inserts an ellipsis once
            # len(row) > 1000.
            y = np.array([" ".join(label_row.astype("str")) for label_row in y])

        classes, y_indices = np.unique(y, return_inverse=True)
        n_classes = classes.shape[0]
        class_counts = np.bincount(y_indices)

        if np.min(class_counts) < 2:
            raise ValueError(
                "The least populated class in y has only 1"
                " member, which is too few. The minimum"
                " number of groups for any class cannot"
                " be less than 2."
            )

        if n_train < n_classes:
            raise ValueError(
                "The train_size = %d should be greater or "
                "equal to the number of classes = %d" % (n_train, n_classes)
            )
        if n_test < n_classes:
            raise ValueError(
                "The test_size = %d should be greater or "
                "equal to the number of classes = %d" % (n_test, n_classes)
            )

        # Group the sample positions by class: argsort puts the members of
        # each class next to each other, and the cumulative counts give the
        # cut points between consecutive classes.
        # (np.unique above already sorts, so this stays O(n log n).)
        positions_by_class = np.argsort(y_indices, kind="mergesort")
        cut_points = np.cumsum(class_counts)[:-1]
        class_indices = np.split(positions_by_class, cut_points)

        rng = check_random_state(self.random_state)

        for _ in range(self.n_splits):
            # Ties in the class counts must be broken anew in every
            # iteration, hence the per-split allocation.
            n_i = _approximate_mode(class_counts, n_train, rng)
            t_i = _approximate_mode(class_counts - n_i, n_test, rng)

            train, test = [], []
            for members, count, k_train, k_test in zip(
                class_indices, class_counts, n_i, t_i
            ):
                shuffled = members.take(rng.permutation(count), mode="clip")
                train.extend(shuffled[:k_train])
                test.extend(shuffled[k_train : k_train + k_test])

            # Shuffle train first, then test, so classes are interleaved.
            yield rng.permutation(train), rng.permutation(test)

    def split(self, X, y, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

            Note that providing ``y`` is sufficient to generate the splits and
            hence ``np.zeros(n_samples)`` may be used as a placeholder for
            ``X`` instead of actual training data.

        y : array-like of shape (n_samples,) or (n_samples, n_labels)
            The target variable for supervised learning problems.
            Stratification is done based on the y labels.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        Randomized CV splitters may return different results for each call of
        split. You can make the results identical by setting `random_state`
        to an integer.
        """
        checked_y = check_array(y, ensure_2d=False, dtype=None)
        return super().split(X, checked_y, groups)
2024
2025
2026def _validate_shuffle_split(n_samples, test_size, train_size, default_test_size=None):
2027    """
2028    Validation helper to check if the test/test sizes are meaningful wrt to the
2029    size of the data (n_samples)
2030    """
2031    if test_size is None and train_size is None:
2032        test_size = default_test_size
2033
2034    test_size_type = np.asarray(test_size).dtype.kind
2035    train_size_type = np.asarray(train_size).dtype.kind
2036
2037    if (
2038        test_size_type == "i"
2039        and (test_size >= n_samples or test_size <= 0)
2040        or test_size_type == "f"
2041        and (test_size <= 0 or test_size >= 1)
2042    ):
2043        raise ValueError(
2044            "test_size={0} should be either positive and smaller"
2045            " than the number of samples {1} or a float in the "
2046            "(0, 1) range".format(test_size, n_samples)
2047        )
2048
2049    if (
2050        train_size_type == "i"
2051        and (train_size >= n_samples or train_size <= 0)
2052        or train_size_type == "f"
2053        and (train_size <= 0 or train_size >= 1)
2054    ):
2055        raise ValueError(
2056            "train_size={0} should be either positive and smaller"
2057            " than the number of samples {1} or a float in the "
2058            "(0, 1) range".format(train_size, n_samples)
2059        )
2060
2061    if train_size is not None and train_size_type not in ("i", "f"):
2062        raise ValueError("Invalid value for train_size: {}".format(train_size))
2063    if test_size is not None and test_size_type not in ("i", "f"):
2064        raise ValueError("Invalid value for test_size: {}".format(test_size))
2065
2066    if train_size_type == "f" and test_size_type == "f" and train_size + test_size > 1:
2067        raise ValueError(
2068            "The sum of test_size and train_size = {}, should be in the (0, 1)"
2069            " range. Reduce test_size and/or train_size.".format(train_size + test_size)
2070        )
2071
2072    if test_size_type == "f":
2073        n_test = ceil(test_size * n_samples)
2074    elif test_size_type == "i":
2075        n_test = float(test_size)
2076
2077    if train_size_type == "f":
2078        n_train = floor(train_size * n_samples)
2079    elif train_size_type == "i":
2080        n_train = float(train_size)
2081
2082    if train_size is None:
2083        n_train = n_samples - n_test
2084    elif test_size is None:
2085        n_test = n_samples - n_train
2086
2087    if n_train + n_test > n_samples:
2088        raise ValueError(
2089            "The sum of train_size and test_size = %d, "
2090            "should be smaller than the number of "
2091            "samples %d. Reduce test_size and/or "
2092            "train_size." % (n_train + n_test, n_samples)
2093        )
2094
2095    n_train, n_test = int(n_train), int(n_test)
2096
2097    if n_train == 0:
2098        raise ValueError(
2099            "With n_samples={}, test_size={} and train_size={}, the "
2100            "resulting train set will be empty. Adjust any of the "
2101            "aforementioned parameters.".format(n_samples, test_size, train_size)
2102        )
2103
2104    return n_train, n_test
2105
2106
class PredefinedSplit(BaseCrossValidator):
    """Predefined split cross-validator.

    Provides train/test indices to split data into train/test sets following
    a scheme fixed in advance by the user through the ``test_fold``
    parameter.

    Read more in the :ref:`User Guide <predefined_split>`.

    .. versionadded:: 0.16

    Parameters
    ----------
    test_fold : array-like of shape (n_samples,)
        The entry ``test_fold[i]`` represents the index of the test set that
        sample ``i`` belongs to. It is possible to exclude sample ``i`` from
        any test set (i.e. include sample ``i`` in every training set) by
        setting ``test_fold[i]`` equal to -1.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import PredefinedSplit
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> test_fold = [0, 1, -1, 1]
    >>> ps = PredefinedSplit(test_fold)
    >>> ps.get_n_splits()
    2
    >>> print(ps)
    PredefinedSplit(test_fold=array([ 0,  1, -1,  1]))
    >>> for train_index, test_index in ps.split():
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [1 2 3] TEST: [0]
    TRAIN: [0 2] TEST: [1 3]
    """

    def __init__(self, test_fold):
        fold_ids = column_or_1d(np.array(test_fold, dtype=int))
        self.test_fold = fold_ids
        distinct_folds = np.unique(fold_ids)
        # -1 marks samples that never belong to a test set, so it is not a
        # fold of its own.
        self.unique_folds = distinct_folds[distinct_folds != -1]

    def split(self, X=None, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        all_indices = np.arange(len(self.test_fold))
        for in_test in self._iter_test_masks():
            # Everything outside the test mask is training data.
            yield all_indices[np.logical_not(in_test)], all_indices[in_test]

    def _iter_test_masks(self):
        """Generates boolean masks corresponding to test sets."""
        for fold_id in self.unique_folds:
            members = np.where(self.test_fold == fold_id)[0]
            mask = np.zeros(len(self.test_fold), dtype=bool)
            mask[members] = True
            yield mask

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Number of distinct test folds (the -1 marker is not counted).
        """
        n_folds = len(self.unique_folds)
        return n_folds
2207
2208
class _CVIterableWrapper(BaseCrossValidator):
    """Wrapper class for old style cv objects and iterables.

    The wrapped iterable is materialized into a list at construction time.
    """

    def __init__(self, cv):
        # Stored as a list so the splits can be counted and iterated again.
        self.cv = list(cv)

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Number of stored ``(train, test)`` splits.
        """
        n_stored = len(self.cv)
        return n_stored

    def split(self, X=None, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        for stored_split in self.cv:
            # Unpack so that every stored entry must be a (train, test) pair.
            train, test = stored_split
            yield train, test
2260
2261
def check_cv(cv=5, y=None, *, classifier=False):
    """Input checker utility for building a cross-validator.

    Parameters
    ----------
    cv : int, cross-validation generator or an iterable, default=5
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 5-fold cross validation,
        - integer, to specify the number of folds.
        - :term:`CV splitter`,
        - An iterable yielding (train, test) splits as arrays of indices.

        For integer/None inputs, if classifier is True and ``y`` is either
        binary or multiclass, :class:`StratifiedKFold` is used. In all other
        cases, :class:`KFold` is used.

        Refer :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

        .. versionchanged:: 0.22
            ``cv`` default value changed from 3-fold to 5-fold.

    y : array-like, default=None
        The target variable for supervised learning problems.

    classifier : bool, default=False
        Whether the task is a classification task, in which case
        stratified KFold will be used.

    Returns
    -------
    checked_cv : a cross-validator instance.
        The return value is a cross-validator which generates the train/test
        splits via the ``split`` method.
    """
    if cv is None:
        cv = 5

    if isinstance(cv, numbers.Integral):
        use_stratified = (
            classifier
            and y is not None
            and type_of_target(y) in ("binary", "multiclass")
        )
        if use_stratified:
            return StratifiedKFold(cv)
        return KFold(cv)

    # Strings expose a ``split`` method and are iterable, but are never valid.
    is_string = isinstance(cv, str)

    if hasattr(cv, "split") and not is_string:
        return cv  # New style cv objects are passed without any modification

    if is_string or not isinstance(cv, Iterable):
        raise ValueError(
            "Expected cv as an integer, cross-validation "
            "object (from sklearn.model_selection) "
            "or an iterable. Got %s." % cv
        )
    return _CVIterableWrapper(cv)
2319
2320
def train_test_split(
    *arrays,
    test_size=None,
    train_size=None,
    random_state=None,
    shuffle=True,
    stratify=None,
):
    """Split arrays or matrices into random train and test subsets.

    Quick utility that wraps input validation and
    ``next(ShuffleSplit().split(X, y))`` and application to input data
    into a single call for splitting (and optionally subsampling) data in a
    oneliner.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    *arrays : sequence of indexables with same length / shape[0]
        Allowed inputs are lists, numpy arrays, scipy-sparse
        matrices or pandas dataframes.

    test_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split. If int, represents the
        absolute number of test samples. If None, the value is set to the
        complement of the train size. If ``train_size`` is also None, it will
        be set to 0.25.

    train_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the dataset to include in the train split. If
        int, represents the absolute number of train samples. If None,
        the value is automatically set to the complement of the test size.

    random_state : int, RandomState instance or None, default=None
        Controls the shuffling applied to the data before applying the split.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    shuffle : bool, default=True
        Whether or not to shuffle the data before splitting. If shuffle=False
        then stratify must be None.

    stratify : array-like, default=None
        If not None, data is split in a stratified fashion, using this as
        the class labels.
        Read more in the :ref:`User Guide <stratification>`.

    Returns
    -------
    splitting : list, length=2 * len(arrays)
        List containing train-test split of inputs.

        .. versionadded:: 0.16
            If the input is sparse, the output will be a
            ``scipy.sparse.csr_matrix``. Else, output type is the same as the
            input type.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import train_test_split
    >>> X, y = np.arange(10).reshape((5, 2)), range(5)
    >>> X
    array([[0, 1],
           [2, 3],
           [4, 5],
           [6, 7],
           [8, 9]])
    >>> list(y)
    [0, 1, 2, 3, 4]

    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...     X, y, test_size=0.33, random_state=42)
    ...
    >>> X_train
    array([[4, 5],
           [0, 1],
           [6, 7]])
    >>> y_train
    [2, 0, 3]
    >>> X_test
    array([[2, 3],
           [8, 9]])
    >>> y_test
    [1, 4]

    >>> train_test_split(y, shuffle=False)
    [[0, 1, 2], [3, 4]]
    """
    if not arrays:
        raise ValueError("At least one array required as input")

    arrays = indexable(*arrays)

    n_samples = _num_samples(arrays[0])
    n_train, n_test = _validate_shuffle_split(
        n_samples, test_size, train_size, default_test_size=0.25
    )

    # Only the literal ``False`` disables shuffling (identity check).
    if shuffle is not False:
        splitter_cls = StratifiedShuffleSplit if stratify is not None else ShuffleSplit
        cv = splitter_cls(
            test_size=n_test, train_size=n_train, random_state=random_state
        )
        train, test = next(cv.split(X=arrays[0], y=stratify))
    elif stratify is not None:
        raise ValueError(
            "Stratified train/test split is not implemented for shuffle=False"
        )
    else:
        # No shuffling: the first n_train rows train, the next n_test test.
        train = np.arange(n_train)
        test = np.arange(n_train, n_train + n_test)

    # Output order is [a0_train, a0_test, a1_train, a1_test, ...].
    splitting = []
    for array in arrays:
        splitting.append(_safe_indexing(array, train))
        splitting.append(_safe_indexing(array, test))
    return splitting
2448
2449
# Mark train_test_split as "not a test" so that nose does not collect it.
# (Needed for external libraries that may still use nose.)
# setattr is used instead of a plain attribute assignment to avoid mypy
# errors when monkeypatching the function object.
setattr(train_test_split, "__test__", False)
2454
2455
def _build_repr(self):
    """Build the ``ClassName(param=value, ...)`` representation of ``self``.

    The parameter names are taken from the signature of the class'
    ``__init__`` (or of its ``deprecated_original`` attribute when present),
    excluding ``self`` and any ``**kwargs`` catch-all, and are listed in
    sorted order. Each value is read from the attribute of the same name; if
    that attribute is missing or ``None`` and the instance has a ``cvargs``
    mapping, the value is looked up there instead.

    Parameters whose attribute access emits a ``FutureWarning`` as its first
    recorded warning are treated as deprecated and omitted from the output.

    Returns
    -------
    repr : str
        The class name followed by the pretty-printed parameters in
        parentheses.
    """
    # XXX This is copied from BaseEstimator's get_params
    cls = self.__class__
    init = getattr(cls.__init__, "deprecated_original", cls.__init__)
    init_signature = signature(init)
    # Consider the constructor parameters excluding 'self' and '**kwargs'
    if init is object.__init__:
        args = []
    else:
        args = sorted(
            p.name
            for p in init_signature.parameters.values()
            if p.name != "self" and p.kind != p.VAR_KEYWORD
        )
    class_name = self.__class__.__name__
    params = {}
    for key in args:
        with warnings.catch_warnings(record=True) as w:
            # We need deprecation warnings to always be on in order to
            # catch deprecated param values.
            # The filter is installed *inside* the context manager so that
            # the global filter list is restored exactly on exit. Installing
            # it outside and undoing it with ``warnings.filters.pop(0)``
            # would drop an identical filter the caller had already set,
            # because ``simplefilter`` de-duplicates before inserting.
            warnings.simplefilter("always", FutureWarning)
            value = getattr(self, key, None)
            if value is None and hasattr(self, "cvargs"):
                value = self.cvargs.get(key, None)
        if len(w) and w[0].category == FutureWarning:
            # if the parameter is deprecated, don't show it
            continue
        params[key] = value

    return "%s(%s)" % (class_name, _pprint(params, offset=len(class_name)))
2494
2495
2496def _yields_constant_splits(cv):
2497    # Return True if calling cv.split() always returns the same splits
2498    # We assume that if a cv doesn't have a shuffle parameter, it shuffles by
2499    # default (e.g. ShuffleSplit). If it actually doesn't shuffle (e.g.
2500    # LeaveOneOut), then it won't have a random_state parameter anyway, in
2501    # which case it will default to 0, leading to output=True
2502    shuffle = getattr(cv, "shuffle", True)
2503    random_state = getattr(cv, "random_state", 0)
2504    return isinstance(random_state, numbers.Integral) or not shuffle
2505