import numpy as np
import scipy.sparse as sp
import warnings
from abc import ABCMeta, abstractmethod

# mypy error: error: Module 'sklearn.svm' has no attribute '_libsvm'
# (and same for other imports)
from . import _libsvm as libsvm  # type: ignore
from . import _liblinear as liblinear  # type: ignore
from . import _libsvm_sparse as libsvm_sparse  # type: ignore
from ..base import BaseEstimator, ClassifierMixin
from ..preprocessing import LabelEncoder
from ..utils.multiclass import _ovr_decision_function
from ..utils import check_array, check_random_state
from ..utils import column_or_1d
from ..utils import compute_class_weight
from ..utils.metaestimators import available_if
from ..utils.deprecation import deprecated
from ..utils.extmath import safe_sparse_dot
from ..utils.validation import check_is_fitted, _check_large_sparse
from ..utils.validation import _num_samples
from ..utils.validation import _check_sample_weight, check_consistent_length
from ..utils.multiclass import check_classification_targets
from ..exceptions import ConvergenceWarning
from ..exceptions import NotFittedError


LIBSVM_IMPL = ["c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr"]

def _one_vs_one_coef(dual_coef, n_support, support_vectors):
    """Generate primal coefficients from dual coefficients
    for the one-vs-one multiclass LibSVM in the case
    of a linear kernel."""

    # get 1vs1 weights for all n * (n - 1) / 2 pairwise classifiers.
    # this is somewhat messy.
    # shape of dual_coef_ is nSV * (n_classes - 1)
    # see docs for details
    n_class = dual_coef.shape[0] + 1

    # XXX we could do preallocation of coef but
    # would have to take care in the sparse case
    coef = []
    sv_locs = np.cumsum(np.hstack([[0], n_support]))
    for class1 in range(n_class):
        # SVs for class1:
        sv1 = support_vectors[sv_locs[class1] : sv_locs[class1 + 1], :]
        for class2 in range(class1 + 1, n_class):
            # SVs for class2:
            sv2 = support_vectors[sv_locs[class2] : sv_locs[class2 + 1], :]

            # dual coef for class1 SVs:
            alpha1 = dual_coef[class2 - 1, sv_locs[class1] : sv_locs[class1 + 1]]
            # dual coef for class2 SVs:
            alpha2 = dual_coef[class1, sv_locs[class2] : sv_locs[class2 + 1]]
            # build weight for class1 vs class2

            coef.append(safe_sparse_dot(alpha1, sv1) + safe_sparse_dot(alpha2, sv2))
    return coef

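
# Illustrative sketch (not part of the public API): for a fitted 3-class
# linear SVC there are 3 * (3 - 1) / 2 = 3 pairwise classifiers, so the
# stacked output of `_one_vs_one_coef` (exposed through `coef_`) has 3 rows.
# The dataset below is synthetic and exists only for demonstration.
def _example_one_vs_one_coef():
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC

    X, y = make_classification(
        n_samples=60, n_features=4, n_informative=3, n_classes=3, random_state=0
    )
    clf = SVC(kernel="linear").fit(X, y)
    assert clf.coef_.shape == (3, 4)  # one row per class pair
    return clf.coef_
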

class BaseLibSVM(BaseEstimator, metaclass=ABCMeta):
    """Base class for estimators that use libsvm as backing library.

    This implements support vector machine classification and regression.

    Parameter documentation is in the derived `SVC` class.
    """

    # The order of these must match the integer values in LibSVM.
    # XXX These are actually the same in the dense case. Need to factor
    # this out.
    _sparse_kernels = ["linear", "poly", "rbf", "sigmoid", "precomputed"]

    @abstractmethod
    def __init__(
        self,
        kernel,
        degree,
        gamma,
        coef0,
        tol,
        C,
        nu,
        epsilon,
        shrinking,
        probability,
        cache_size,
        class_weight,
        verbose,
        max_iter,
        random_state,
    ):

        if self._impl not in LIBSVM_IMPL:
            raise ValueError(
                "impl should be one of %s, %s was given" % (LIBSVM_IMPL, self._impl)
            )

        if gamma == 0:
            msg = (
                "The gamma value of 0.0 is invalid. Use 'auto' to set"
                " gamma to a value of 1 / n_features."
            )
            raise ValueError(msg)

        self.kernel = kernel
        self.degree = degree
        self.gamma = gamma
        self.coef0 = coef0
        self.tol = tol
        self.C = C
        self.nu = nu
        self.epsilon = epsilon
        self.shrinking = shrinking
        self.probability = probability
        self.cache_size = cache_size
        self.class_weight = class_weight
        self.verbose = verbose
        self.max_iter = max_iter
        self.random_state = random_state

    def _more_tags(self):
        # Used by cross_val_score.
        return {"pairwise": self.kernel == "precomputed"}

    # TODO: Remove in 1.1
    # mypy error: Decorated property not supported
    @deprecated(  # type: ignore
        "Attribute `_pairwise` was deprecated in "
        "version 0.24 and will be removed in 1.1 (renaming of 0.26)."
    )
    @property
    def _pairwise(self):
        # Used by cross_val_score.
        return self.kernel == "precomputed"

    def fit(self, X, y, sample_weight=None):
        """Fit the SVM model according to the given training data.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features) \
                or (n_samples, n_samples)
            Training vectors, where `n_samples` is the number of samples
            and `n_features` is the number of features.
            For kernel="precomputed", the expected shape of X is
            (n_samples, n_samples).

        y : array-like of shape (n_samples,)
            Target values (class labels in classification, real numbers in
            regression).

        sample_weight : array-like of shape (n_samples,), default=None
            Per-sample weights. Rescale C per sample. Higher weights
            force the classifier to put more emphasis on these points.

        Returns
        -------
        self : object
            Fitted estimator.

        Notes
        -----
        If X and y are not C-ordered and contiguous arrays of np.float64 and
        X is not a scipy.sparse.csr_matrix, X and/or y may be copied.

        If X is a dense array, then the other methods will not support sparse
        matrices as input.
        """

        rnd = check_random_state(self.random_state)

        sparse = sp.isspmatrix(X)
        if sparse and self.kernel == "precomputed":
            raise TypeError("Sparse precomputed kernels are not supported.")
        self._sparse = sparse and not callable(self.kernel)

        if hasattr(self, "decision_function_shape"):
            if self.decision_function_shape not in ("ovr", "ovo"):
                raise ValueError(
                    "decision_function_shape must be either 'ovr' or 'ovo', "
                    f"got {self.decision_function_shape}."
                )

        if callable(self.kernel):
            check_consistent_length(X, y)
        else:
            X, y = self._validate_data(
                X,
                y,
                dtype=np.float64,
                order="C",
                accept_sparse="csr",
                accept_large_sparse=False,
            )

        y = self._validate_targets(y)

        sample_weight = np.asarray(
            [] if sample_weight is None else sample_weight, dtype=np.float64
        )
        solver_type = LIBSVM_IMPL.index(self._impl)

        # input validation
        n_samples = _num_samples(X)
        if solver_type != 2 and n_samples != y.shape[0]:
            raise ValueError(
                "X and y have incompatible shapes.\n"
                + "X has %s samples, but y has %s." % (n_samples, y.shape[0])
            )

        if self.kernel == "precomputed" and n_samples != X.shape[1]:
            raise ValueError(
                "Precomputed matrix must be a square matrix."
                " Input is a {}x{} matrix.".format(X.shape[0], X.shape[1])
            )

        if sample_weight.shape[0] > 0 and sample_weight.shape[0] != n_samples:
            raise ValueError(
                "sample_weight and X have incompatible shapes: "
                "%r vs %r\n"
                "Note: Sparse matrices cannot be indexed w/"
                "boolean masks (use `indices=True` in CV)."
                % (sample_weight.shape, X.shape)
            )

        kernel = "precomputed" if callable(self.kernel) else self.kernel

        if kernel == "precomputed":
            # unused but needs to be a float for cython code that ignores
            # it anyway
            self._gamma = 0.0
        elif isinstance(self.gamma, str):
            if self.gamma == "scale":
                # var = E[X^2] - E[X]^2 if sparse
                X_var = (X.multiply(X)).mean() - (X.mean()) ** 2 if sparse else X.var()
                self._gamma = 1.0 / (X.shape[1] * X_var) if X_var != 0 else 1.0
            elif self.gamma == "auto":
                self._gamma = 1.0 / X.shape[1]
            else:
                raise ValueError(
                    "When 'gamma' is a string, it should be either 'scale' or "
                    "'auto'. Got '{}' instead.".format(self.gamma)
                )
        else:
            self._gamma = self.gamma

        fit = self._sparse_fit if self._sparse else self._dense_fit
        if self.verbose:
            print("[LibSVM]", end="")

        seed = rnd.randint(np.iinfo("i").max)
        fit(X, y, sample_weight, solver_type, kernel, random_seed=seed)
        # see comment on the other call to np.iinfo in this file

        self.shape_fit_ = X.shape if hasattr(X, "shape") else (n_samples,)

        # In binary case, we need to flip the sign of coef, intercept and
        # decision function. Use self._intercept_ and self._dual_coef_
        # internally.
        self._intercept_ = self.intercept_.copy()
        self._dual_coef_ = self.dual_coef_
        if self._impl in ["c_svc", "nu_svc"] and len(self.classes_) == 2:
            self.intercept_ *= -1
            self.dual_coef_ = -self.dual_coef_

        return self
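
    # Usage sketch for `fit` (illustrative only; `X_train`/`y_train` are
    # hypothetical arrays and `SVR` stands in for any concrete subclass):
    #
    #     >>> from sklearn.svm import SVR
    #     >>> svr = SVR(kernel="rbf", gamma="scale").fit(X_train, y_train)
    #     >>> svr.shape_fit_ == X_train.shape
    #     True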

    def _validate_targets(self, y):
        """Validation of y and class_weight.

        Default implementation for SVR and one-class; overridden in BaseSVC.
        """
        # XXX this is ugly.
        # Regression models should not have a class_weight_ attribute.
        self.class_weight_ = np.empty(0)
        return column_or_1d(y, warn=True).astype(np.float64, copy=False)

    def _warn_from_fit_status(self):
        assert self.fit_status_ in (0, 1)
        if self.fit_status_ == 1:
            warnings.warn(
                "Solver terminated early (max_iter=%i). Consider"
                " pre-processing your data with StandardScaler or"
                " MinMaxScaler." % self.max_iter,
                ConvergenceWarning,
            )

    def _dense_fit(self, X, y, sample_weight, solver_type, kernel, random_seed):
        if callable(self.kernel):
            # you must store a reference to X to compute the kernel in predict
            # TODO: add keyword copy to copy on demand
            self.__Xfit = X
            X = self._compute_kernel(X)

            if X.shape[0] != X.shape[1]:
                raise ValueError("X.shape[0] should be equal to X.shape[1]")

        libsvm.set_verbosity_wrap(self.verbose)

        # we don't pass **self.get_params() to allow subclasses to
        # add other parameters to __init__
        (
            self.support_,
            self.support_vectors_,
            self._n_support,
            self.dual_coef_,
            self.intercept_,
            self._probA,
            self._probB,
            self.fit_status_,
        ) = libsvm.fit(
            X,
            y,
            svm_type=solver_type,
            sample_weight=sample_weight,
            class_weight=self.class_weight_,
            kernel=kernel,
            C=self.C,
            nu=self.nu,
            probability=self.probability,
            degree=self.degree,
            shrinking=self.shrinking,
            tol=self.tol,
            cache_size=self.cache_size,
            coef0=self.coef0,
            gamma=self._gamma,
            epsilon=self.epsilon,
            max_iter=self.max_iter,
            random_seed=random_seed,
        )

        self._warn_from_fit_status()

    def _sparse_fit(self, X, y, sample_weight, solver_type, kernel, random_seed):
        X.data = np.asarray(X.data, dtype=np.float64, order="C")
        X.sort_indices()

        kernel_type = self._sparse_kernels.index(kernel)

        libsvm_sparse.set_verbosity_wrap(self.verbose)

        (
            self.support_,
            self.support_vectors_,
            dual_coef_data,
            self.intercept_,
            self._n_support,
            self._probA,
            self._probB,
            self.fit_status_,
        ) = libsvm_sparse.libsvm_sparse_train(
            X.shape[1],
            X.data,
            X.indices,
            X.indptr,
            y,
            solver_type,
            kernel_type,
            self.degree,
            self._gamma,
            self.coef0,
            self.tol,
            self.C,
            self.class_weight_,
            sample_weight,
            self.nu,
            self.cache_size,
            self.epsilon,
            int(self.shrinking),
            int(self.probability),
            self.max_iter,
            random_seed,
        )

        self._warn_from_fit_status()

        if hasattr(self, "classes_"):
            n_class = len(self.classes_) - 1
        else:  # regression
            n_class = 1
        n_SV = self.support_vectors_.shape[0]

        dual_coef_indices = np.tile(np.arange(n_SV), n_class)
        if not n_SV:
            self.dual_coef_ = sp.csr_matrix([])
        else:
            dual_coef_indptr = np.arange(
                0, dual_coef_indices.size + 1, dual_coef_indices.size / n_class
            )
            self.dual_coef_ = sp.csr_matrix(
                (dual_coef_data, dual_coef_indices, dual_coef_indptr), (n_class, n_SV)
            )

    def predict(self, X):
        """Perform regression on samples in X.

        For a one-class model, +1 (inlier) or -1 (outlier) is returned.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            For kernel="precomputed", the expected shape of X is
            (n_samples_test, n_samples_train).

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            The predicted values.
        """
        X = self._validate_for_predict(X)
        predict = self._sparse_predict if self._sparse else self._dense_predict
        return predict(X)

    def _dense_predict(self, X):
        X = self._compute_kernel(X)
        if X.ndim == 1:
            X = check_array(X, order="C", accept_large_sparse=False)

        kernel = self.kernel
        if callable(self.kernel):
            kernel = "precomputed"
            if X.shape[1] != self.shape_fit_[0]:
                raise ValueError(
                    "X.shape[1] = %d should be equal to %d, "
                    "the number of samples at training time"
                    % (X.shape[1], self.shape_fit_[0])
                )

        svm_type = LIBSVM_IMPL.index(self._impl)

        return libsvm.predict(
            X,
            self.support_,
            self.support_vectors_,
            self._n_support,
            self._dual_coef_,
            self._intercept_,
            self._probA,
            self._probB,
            svm_type=svm_type,
            kernel=kernel,
            degree=self.degree,
            coef0=self.coef0,
            gamma=self._gamma,
            cache_size=self.cache_size,
        )

    def _sparse_predict(self, X):
        # Precondition: X is a csr_matrix of dtype np.float64.
        kernel = self.kernel
        if callable(kernel):
            kernel = "precomputed"

        kernel_type = self._sparse_kernels.index(kernel)

        C = 0.0  # C is not useful here

        return libsvm_sparse.libsvm_sparse_predict(
            X.data,
            X.indices,
            X.indptr,
            self.support_vectors_.data,
            self.support_vectors_.indices,
            self.support_vectors_.indptr,
            self._dual_coef_.data,
            self._intercept_,
            LIBSVM_IMPL.index(self._impl),
            kernel_type,
            self.degree,
            self._gamma,
            self.coef0,
            self.tol,
            C,
            self.class_weight_,
            self.nu,
            self.epsilon,
            self.shrinking,
            self.probability,
            self._n_support,
            self._probA,
            self._probB,
        )

    def _compute_kernel(self, X):
        """Return the data transformed by a callable kernel"""
        if callable(self.kernel):
            # in the case of precomputed kernel given as a function, we
            # have to compute explicitly the kernel matrix
            kernel = self.kernel(X, self.__Xfit)
            if sp.issparse(kernel):
                kernel = kernel.toarray()
            X = np.asarray(kernel, dtype=np.float64, order="C")
        return X
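
    # Sketch of the callable-kernel path handled above (illustrative only):
    # a callable kernel receives the test samples and the stored training
    # samples (`self.__Xfit`) and must return the Gram matrix between them.
    #
    #     >>> from sklearn.svm import SVC
    #     >>> linear = lambda A, B: A @ B.T  # behaves like kernel="linear"
    #     >>> clf = SVC(kernel=linear).fit(X_train, y_train)
    #     >>> pred = clf.predict(X_test)  # computes linear(X_test, X_train)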

    def _decision_function(self, X):
        """Evaluates the decision function for the samples in X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        dec : array-like of shape (n_samples, n_classes * (n_classes - 1) / 2)
            Returns the decision function of the sample for each class
            in the model.
        """
        # NOTE: _validate_for_predict contains check for is_fitted
        # hence must be placed before any other attributes are used.
        X = self._validate_for_predict(X)
        X = self._compute_kernel(X)

        if self._sparse:
            dec_func = self._sparse_decision_function(X)
        else:
            dec_func = self._dense_decision_function(X)

        # In binary case, we need to flip the sign of coef, intercept and
        # decision function.
        if self._impl in ["c_svc", "nu_svc"] and len(self.classes_) == 2:
            return -dec_func.ravel()

        return dec_func

    def _dense_decision_function(self, X):
        X = check_array(X, dtype=np.float64, order="C", accept_large_sparse=False)

        kernel = self.kernel
        if callable(kernel):
            kernel = "precomputed"

        return libsvm.decision_function(
            X,
            self.support_,
            self.support_vectors_,
            self._n_support,
            self._dual_coef_,
            self._intercept_,
            self._probA,
            self._probB,
            svm_type=LIBSVM_IMPL.index(self._impl),
            kernel=kernel,
            degree=self.degree,
            cache_size=self.cache_size,
            coef0=self.coef0,
            gamma=self._gamma,
        )

    def _sparse_decision_function(self, X):
        X.data = np.asarray(X.data, dtype=np.float64, order="C")

        kernel = self.kernel
        if callable(kernel):
            kernel = "precomputed"

        kernel_type = self._sparse_kernels.index(kernel)

        return libsvm_sparse.libsvm_sparse_decision_function(
            X.data,
            X.indices,
            X.indptr,
            self.support_vectors_.data,
            self.support_vectors_.indices,
            self.support_vectors_.indptr,
            self._dual_coef_.data,
            self._intercept_,
            LIBSVM_IMPL.index(self._impl),
            kernel_type,
            self.degree,
            self._gamma,
            self.coef0,
            self.tol,
            self.C,
            self.class_weight_,
            self.nu,
            self.epsilon,
            self.shrinking,
            self.probability,
            self._n_support,
            self._probA,
            self._probB,
        )

    def _validate_for_predict(self, X):
        check_is_fitted(self)

        if not callable(self.kernel):
            X = self._validate_data(
                X,
                accept_sparse="csr",
                dtype=np.float64,
                order="C",
                accept_large_sparse=False,
                reset=False,
            )

        if self._sparse and not sp.isspmatrix(X):
            X = sp.csr_matrix(X)
        if self._sparse:
            X.sort_indices()

        if sp.issparse(X) and not self._sparse and not callable(self.kernel):
            raise ValueError(
                "cannot use sparse input in %r trained on dense data"
                % type(self).__name__
            )

        if self.kernel == "precomputed":
            if X.shape[1] != self.shape_fit_[0]:
                raise ValueError(
                    "X.shape[1] = %d should be equal to %d, "
                    "the number of samples at training time"
                    % (X.shape[1], self.shape_fit_[0])
                )
        # Fixes https://nvd.nist.gov/vuln/detail/CVE-2020-28975
        # Check that _n_support is consistent with support_vectors
        sv = self.support_vectors_
        if not self._sparse and sv.size > 0 and self.n_support_.sum() != sv.shape[0]:
            raise ValueError(
                f"The internal representation of {self.__class__.__name__} was altered"
            )
        return X

    @property
    def coef_(self):
        """Weights assigned to the features when `kernel="linear"`.

        Returns
        -------
        ndarray of shape (n_classes * (n_classes - 1) / 2, n_features)
        """
        if self.kernel != "linear":
            raise AttributeError("coef_ is only available when using a linear kernel")

        coef = self._get_coef()

        # coef_ being a read-only property, it's better to mark the value as
        # immutable to avoid hiding potential bugs for the unsuspecting user.
        if sp.issparse(coef):
            # sparse matrices do not have global flags
            coef.data.flags.writeable = False
        else:
            # regular dense array
            coef.flags.writeable = False
        return coef

    def _get_coef(self):
        return safe_sparse_dot(self._dual_coef_, self.support_vectors_)

    @property
    def n_support_(self):
        """Number of support vectors for each class."""
        try:
            check_is_fitted(self)
        except NotFittedError:
            raise AttributeError

        svm_type = LIBSVM_IMPL.index(self._impl)
        if svm_type in (0, 1):
            return self._n_support
        else:
            # SVR and OneClass
            # _n_support has size 2, we make it size 1
            return np.array([self._n_support[0]])

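
# Illustrative sketch of the ``gamma="scale"`` heuristic used in
# ``BaseLibSVM.fit`` above: for dense input it reduces to
# ``1 / (n_features * X.var())``. The numbers are synthetic.
def _example_gamma_scale():
    X = np.array([[0.0, 1.0], [2.0, 3.0]])
    # X.var() == 1.25 and n_features == 2, so gamma == 1 / (2 * 1.25) == 0.4
    return 1.0 / (X.shape[1] * X.var())
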

class BaseSVC(ClassifierMixin, BaseLibSVM, metaclass=ABCMeta):
    """ABC for LibSVM-based classifiers."""

    @abstractmethod
    def __init__(
        self,
        kernel,
        degree,
        gamma,
        coef0,
        tol,
        C,
        nu,
        shrinking,
        probability,
        cache_size,
        class_weight,
        verbose,
        max_iter,
        decision_function_shape,
        random_state,
        break_ties,
    ):
        self.decision_function_shape = decision_function_shape
        self.break_ties = break_ties
        super().__init__(
            kernel=kernel,
            degree=degree,
            gamma=gamma,
            coef0=coef0,
            tol=tol,
            C=C,
            nu=nu,
            epsilon=0.0,
            shrinking=shrinking,
            probability=probability,
            cache_size=cache_size,
            class_weight=class_weight,
            verbose=verbose,
            max_iter=max_iter,
            random_state=random_state,
        )

    def _validate_targets(self, y):
        y_ = column_or_1d(y, warn=True)
        check_classification_targets(y)
        cls, y = np.unique(y_, return_inverse=True)
        self.class_weight_ = compute_class_weight(self.class_weight, classes=cls, y=y_)
        if len(cls) < 2:
            raise ValueError(
                "The number of classes has to be greater than one; got %d class"
                % len(cls)
            )

        self.classes_ = cls

        return np.asarray(y, dtype=np.float64, order="C")

    def decision_function(self, X):
        """Evaluate the decision function for the samples in X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        dec : ndarray of shape (n_samples, n_classes * (n_classes - 1) / 2)
            Returns the decision function of the sample for each class
            in the model.
            If decision_function_shape='ovr', the shape is (n_samples,
            n_classes).

        Notes
        -----
        If decision_function_shape='ovo', the function values are proportional
        to the distance of the samples X to the separating hyperplane. If the
        exact distances are required, divide the function values by the norm of
        the weight vector (``coef_``). See also `this question
        <https://stats.stackexchange.com/questions/14876/
        interpreting-distance-from-hyperplane-in-svm>`_ for further details.
        If decision_function_shape='ovr', the decision function is a monotonic
        transformation of the ovo decision function.
        """
        dec = self._decision_function(X)
        if self.decision_function_shape == "ovr" and len(self.classes_) > 2:
            return _ovr_decision_function(dec < 0, -dec, len(self.classes_))
        return dec
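
    # Shape sketch (illustrative; `X`/`y` are hypothetical arrays with four
    # classes): "ovo" yields one column per class pair, "ovr" one per class.
    #
    #     >>> SVC(decision_function_shape="ovo").fit(X, y).decision_function(X).shape
    #     (n_samples, 6)   # 4 * (4 - 1) / 2 pairwise columns
    #     >>> SVC(decision_function_shape="ovr").fit(X, y).decision_function(X).shape
    #     (n_samples, 4)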

    def predict(self, X):
        """Perform classification on samples in X.

        For a one-class model, +1 or -1 is returned.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features) or \
                (n_samples_test, n_samples_train)
            For kernel="precomputed", the expected shape of X is
            (n_samples_test, n_samples_train).

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Class labels for samples in X.
        """
        check_is_fitted(self)
        if self.break_ties and self.decision_function_shape == "ovo":
            raise ValueError(
                "break_ties must be False when decision_function_shape is 'ovo'"
            )

        if (
            self.break_ties
            and self.decision_function_shape == "ovr"
            and len(self.classes_) > 2
        ):
            y = np.argmax(self.decision_function(X), axis=1)
        else:
            y = super().predict(X)
        return self.classes_.take(np.asarray(y, dtype=np.intp))

    # Hacky way of getting predict_proba to raise an AttributeError when
    # probability=False using properties. Do not use this in new code; when
    # probabilities are not available depending on a setting, introduce two
    # estimators.
    def _check_proba(self):
        if not self.probability:
            raise AttributeError(
                "predict_proba is not available when probability=False"
            )
        if self._impl not in ("c_svc", "nu_svc"):
            raise AttributeError("predict_proba only implemented for SVC and NuSVC")
        return True

    @available_if(_check_proba)
    def predict_proba(self, X):
        """Compute probabilities of possible outcomes for samples in X.

        The model needs to have probability information computed at training
        time: fit with attribute `probability` set to True.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            For kernel="precomputed", the expected shape of X is
            (n_samples_test, n_samples_train).

        Returns
        -------
        T : ndarray of shape (n_samples, n_classes)
            Returns the probability of the sample for each class in
            the model. The columns correspond to the classes in sorted
            order, as they appear in the attribute :term:`classes_`.

        Notes
        -----
        The probability model is created using cross validation, so
        the results can be slightly different than those obtained by
        predict. Also, it will produce meaningless results on very small
        datasets.
        """
        X = self._validate_for_predict(X)
        if self.probA_.size == 0 or self.probB_.size == 0:
            raise NotFittedError(
                "predict_proba is not available when fitted with probability=False"
            )
        pred_proba = (
            self._sparse_predict_proba if self._sparse else self._dense_predict_proba
        )
        return pred_proba(X)

    @available_if(_check_proba)
    def predict_log_proba(self, X):
        """Compute log probabilities of possible outcomes for samples in X.

        The model needs to have probability information computed at training
        time: fit with attribute `probability` set to True.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features) or \
                (n_samples_test, n_samples_train)
            For kernel="precomputed", the expected shape of X is
            (n_samples_test, n_samples_train).

        Returns
        -------
        T : ndarray of shape (n_samples, n_classes)
            Returns the log-probabilities of the sample for each class in
            the model. The columns correspond to the classes in sorted
            order, as they appear in the attribute :term:`classes_`.

        Notes
        -----
        The probability model is created using cross validation, so
        the results can be slightly different than those obtained by
        predict. Also, it will produce meaningless results on very small
        datasets.
        """
        return np.log(self.predict_proba(X))

    def _dense_predict_proba(self, X):
        X = self._compute_kernel(X)

        kernel = self.kernel
        if callable(kernel):
            kernel = "precomputed"

        svm_type = LIBSVM_IMPL.index(self._impl)
        pprob = libsvm.predict_proba(
            X,
            self.support_,
            self.support_vectors_,
            self._n_support,
            self._dual_coef_,
            self._intercept_,
            self._probA,
            self._probB,
            svm_type=svm_type,
            kernel=kernel,
            degree=self.degree,
            cache_size=self.cache_size,
            coef0=self.coef0,
            gamma=self._gamma,
        )

        return pprob

    def _sparse_predict_proba(self, X):
        X.data = np.asarray(X.data, dtype=np.float64, order="C")

        kernel = self.kernel
        if callable(kernel):
            kernel = "precomputed"

        kernel_type = self._sparse_kernels.index(kernel)

        return libsvm_sparse.libsvm_sparse_predict_proba(
            X.data,
            X.indices,
            X.indptr,
            self.support_vectors_.data,
            self.support_vectors_.indices,
            self.support_vectors_.indptr,
            self._dual_coef_.data,
            self._intercept_,
            LIBSVM_IMPL.index(self._impl),
            kernel_type,
            self.degree,
            self._gamma,
            self.coef0,
            self.tol,
            self.C,
            self.class_weight_,
            self.nu,
            self.epsilon,
            self.shrinking,
            self.probability,
            self._n_support,
            self._probA,
            self._probB,
        )

    def _get_coef(self):
        if self.dual_coef_.shape[0] == 1:
            # binary classifier
            coef = safe_sparse_dot(self.dual_coef_, self.support_vectors_)
        else:
            # 1vs1 classifier
            coef = _one_vs_one_coef(
                self.dual_coef_, self._n_support, self.support_vectors_
            )
            if sp.issparse(coef[0]):
                coef = sp.vstack(coef).tocsr()
            else:
                coef = np.vstack(coef)

        return coef

    @property
    def probA_(self):
        """Parameter learned in Platt scaling when `probability=True`.

        Returns
        -------
        ndarray of shape (n_classes * (n_classes - 1) / 2,)
        """
        return self._probA

    @property
    def probB_(self):
        """Parameter learned in Platt scaling when `probability=True`.

        Returns
        -------
        ndarray of shape (n_classes * (n_classes - 1) / 2,)
        """
        return self._probB

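
# Illustrative sketch of the probability machinery gated by `_check_proba`
# above (synthetic data; not part of the public API):
def _example_predict_proba():
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC

    X, y = make_classification(n_samples=50, random_state=0)
    clf = SVC(probability=True, random_state=0).fit(X, y)
    proba = clf.predict_proba(X)  # Platt-scaled, fitted via internal CV
    assert proba.shape == (50, 2)
    assert clf.probA_.shape == (1,)  # one Platt parameter pair for 2 classes
    return proba
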

def _get_liblinear_solver_type(multi_class, penalty, loss, dual):
    """Find the liblinear magic number for the solver.

    This number depends on the values of the following attributes:
      - multi_class
      - penalty
      - loss
      - dual

    The same number is also internally used by LibLinear to determine
    which solver to use.
    """
    # nested dicts containing level 1: available loss functions,
    # level2: available penalties for the given loss function,
    # level3: whether the dual solver is available for the specified
    # combination of loss function and penalty
    _solver_type_dict = {
        "logistic_regression": {"l1": {False: 6}, "l2": {False: 0, True: 7}},
        "hinge": {"l2": {True: 3}},
        "squared_hinge": {"l1": {False: 5}, "l2": {False: 2, True: 1}},
        "epsilon_insensitive": {"l2": {True: 13}},
        "squared_epsilon_insensitive": {"l2": {False: 11, True: 12}},
        "crammer_singer": 4,
    }

    if multi_class == "crammer_singer":
        return _solver_type_dict[multi_class]
    elif multi_class != "ovr":
        raise ValueError(
            "`multi_class` must be one of `ovr`, `crammer_singer`, got %r" % multi_class
        )

    _solver_pen = _solver_type_dict.get(loss, None)
    if _solver_pen is None:
        error_string = "loss='%s' is not supported" % loss
    else:
        _solver_dual = _solver_pen.get(penalty, None)
        if _solver_dual is None:
            error_string = (
                "The combination of penalty='%s' and loss='%s' is not supported"
                % (penalty, loss)
            )
        else:
            solver_num = _solver_dual.get(dual, None)
            if solver_num is None:
                error_string = (
                    "The combination of penalty='%s' and "
                    "loss='%s' is not supported when dual=%s" % (penalty, loss, dual)
                )
            else:
                return solver_num
    raise ValueError(
        "Unsupported set of arguments: %s, Parameters: penalty=%r, loss=%r, dual=%r"
        % (error_string, penalty, loss, dual)
    )

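# Sketch of the solver lookup above (illustrative): the default LinearSVC
# configuration (ovr, l2 penalty, squared_hinge loss, dual=True) maps to
# liblinear solver number 1.
def _example_solver_type():
    assert _get_liblinear_solver_type("ovr", "l2", "squared_hinge", True) == 1
    assert _get_liblinear_solver_type("ovr", "l2", "hinge", True) == 3
    # penalty, loss and dual are ignored for crammer_singer
    assert _get_liblinear_solver_type("crammer_singer", None, None, None) == 4
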
def _fit_liblinear(
    X,
    y,
    C,
    fit_intercept,
    intercept_scaling,
    class_weight,
    penalty,
    dual,
    verbose,
    max_iter,
    tol,
    random_state=None,
    multi_class="ovr",
    loss="logistic_regression",
    epsilon=0.1,
    sample_weight=None,
):
    """Used by Logistic Regression (and CV) and LinearSVC/LinearSVR.

    Preprocessing is done in this function before supplying it to liblinear.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Training vector, where `n_samples` is the number of samples and
        `n_features` is the number of features.

    y : array-like of shape (n_samples,)
        Target vector relative to X.

    C : float
        Inverse of regularization strength; the lower the C, the stronger
        the penalization.

    fit_intercept : bool
        Whether or not to fit an intercept, that is to add an intercept
        term to the decision function.

    intercept_scaling : float
        LibLinear internally penalizes the intercept, and this term is
        subject to regularization just like the other terms of the feature
        vector. To avoid this, one should increase intercept_scaling such
        that the feature vector becomes [x, intercept_scaling].

    class_weight : dict or 'balanced', default=None
        Weights associated with classes in the form ``{class_label: weight}``.
        If not given, all classes are supposed to have weight one. For
        multi-output problems, a list of dicts can be provided in the same
        order as the columns of y.

        The "balanced" mode uses the values of y to automatically adjust
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``.

    penalty : {'l1', 'l2'}
        The norm of the penalty used in regularization.

    dual : bool
        Dual or primal formulation.

    verbose : int
        Set verbose to any positive number for verbosity.

    max_iter : int
        Maximum number of iterations.

    tol : float
        Tolerance for the stopping condition.

    random_state : int, RandomState instance or None, default=None
        Controls the pseudo random number generation for shuffling the data.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    multi_class : {'ovr', 'crammer_singer'}, default='ovr'
        `ovr` trains n_classes one-vs-rest classifiers, while `crammer_singer`
        optimizes a joint objective over all classes.
        While `crammer_singer` is interesting from a theoretical perspective
        as it is consistent, it is seldom used in practice, rarely leads to
        better accuracy, and is more expensive to compute.
        If `crammer_singer` is chosen, the options loss, penalty and dual will
        be ignored.

    loss : {'logistic_regression', 'hinge', 'squared_hinge', \
            'epsilon_insensitive', 'squared_epsilon_insensitive'}, \
            default='logistic_regression'
        The loss function used to fit the model.

    epsilon : float, default=0.1
        Epsilon parameter in the epsilon-insensitive loss function. Note
        that the value of this parameter depends on the scale of the target
        variable y. If unsure, set epsilon=0.

    sample_weight : array-like of shape (n_samples,), default=None
        Weights assigned to each sample.

    Returns
    -------
    coef_ : ndarray of shape (n_classes, n_features)
        The coefficient vector obtained by minimizing the objective function.

    intercept_ : ndarray of shape (n_classes,) or float
        The intercept term. Equal to 0.0 if `fit_intercept` is False.

    n_iter_ : int
        Maximum number of iterations run across all classes.
    """
    if loss not in ["epsilon_insensitive", "squared_epsilon_insensitive"]:
        enc = LabelEncoder()
        y_ind = enc.fit_transform(y)
        classes_ = enc.classes_
        if len(classes_) < 2:
            raise ValueError(
                "This solver needs samples of at least 2 classes"
                " in the data, but the data contains only one"
                " class: %r"
                % classes_[0]
            )

        class_weight_ = compute_class_weight(class_weight, classes=classes_, y=y)
    else:
        class_weight_ = np.empty(0, dtype=np.float64)
        y_ind = y
    liblinear.set_verbosity_wrap(verbose)
    rnd = check_random_state(random_state)
    if verbose:
        print("[LibLinear]", end="")

    # LinearSVC breaks when intercept_scaling is <= 0
    bias = -1.0
    if fit_intercept:
        if intercept_scaling <= 0:
            raise ValueError(
                "Intercept scaling is %r but needs to be greater "
                "than 0. To disable fitting an intercept,"
                " set fit_intercept=False." % intercept_scaling
            )
        else:
            bias = intercept_scaling

    libsvm.set_verbosity_wrap(verbose)
    libsvm_sparse.set_verbosity_wrap(verbose)
    liblinear.set_verbosity_wrap(verbose)

    # Liblinear doesn't support 64bit sparse matrix indices yet
    if sp.issparse(X):
        _check_large_sparse(X)

    # LibLinear wants targets as doubles, even for classification
    y_ind = np.asarray(y_ind, dtype=np.float64).ravel()
    y_ind = np.require(y_ind, requirements="W")

    sample_weight = _check_sample_weight(sample_weight, X, dtype=np.float64)

    solver_type = _get_liblinear_solver_type(multi_class, penalty, loss, dual)
    raw_coef_, n_iter_ = liblinear.train_wrap(
        X,
        y_ind,
        sp.isspmatrix(X),
        solver_type,
        tol,
        bias,
        C,
        class_weight_,
        max_iter,
        rnd.randint(np.iinfo("i").max),
        epsilon,
        sample_weight,
    )
    # Regarding rnd.randint(..) in the above signature:
    # seed for srand in range [0..INT_MAX); due to limitations in Numpy
    # on 32-bit platforms, we can't get to the UINT_MAX limit that
    # srand supports
    n_iter_ = max(n_iter_)
    if n_iter_ >= max_iter:
        warnings.warn(
            "Liblinear failed to converge, increase the number of iterations.",
            ConvergenceWarning,
        )

    if fit_intercept:
        coef_ = raw_coef_[:, :-1]
        intercept_ = intercept_scaling * raw_coef_[:, -1]
    else:
        coef_ = raw_coef_
        intercept_ = 0.0

    return coef_, intercept_, n_iter_
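
# End-to-end sketch (runs only when this module is executed directly):
# `LinearSVC` is one public entry point that funnels into `_fit_liblinear`.
# The data is synthetic, for illustration only.
if __name__ == "__main__":
    from sklearn.datasets import make_classification
    from sklearn.svm import LinearSVC

    X, y = make_classification(n_samples=100, n_features=5, random_state=0)
    clf = LinearSVC(C=1.0, max_iter=2000, random_state=0).fit(X, y)
    print("coef_ shape:", clf.coef_.shape)  # (1, 5) for binary targets
    print("n_iter_:", clf.n_iter_)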