# Authors: Rob Zinkov, Mathieu Blondel
# License: BSD 3 clause

from ._stochastic_gradient import BaseSGDClassifier
from ._stochastic_gradient import BaseSGDRegressor
from ._stochastic_gradient import DEFAULT_EPSILON


def _pa_variant(loss, pa1_loss):
    """Return the learning-rate name matching the requested PA variant.

    Parameters
    ----------
    loss : str
        The ``loss`` parameter of the estimator.

    pa1_loss : str
        The loss name that selects the PA-I variant for that estimator.

    Returns
    -------
    variant : str
        ``"pa1"`` when ``loss`` equals ``pa1_loss``, ``"pa2"`` for any other
        value.
    """
    if loss == pa1_loss:
        return "pa1"
    return "pa2"


class PassiveAggressiveClassifier(BaseSGDClassifier):
    """Passive Aggressive Classifier.

    Read more in the :ref:`User Guide <passive_aggressive>`.

    Parameters
    ----------
    C : float, default=1.0
        Maximum step size (regularization).

    fit_intercept : bool, default=True
        Whether the intercept should be estimated or not. If False, the
        data is assumed to be already centered.

    max_iter : int, default=1000
        The maximum number of passes over the training data (aka epochs).
        It only impacts the behavior in the ``fit`` method, and not the
        :meth:`partial_fit` method.

        .. versionadded:: 0.19

    tol : float or None, default=1e-3
        The stopping criterion. If it is not None, the iterations will stop
        when (loss > previous_loss - tol).

        .. versionadded:: 0.19

    early_stopping : bool, default=False
        Whether to use early stopping to terminate training when the
        validation score is not improving. If set to True, a stratified
        fraction of the training data is automatically set aside as a
        validation set, and training terminates when the validation score
        is not improving by at least tol for n_iter_no_change consecutive
        epochs.

        .. versionadded:: 0.20

    validation_fraction : float, default=0.1
        The proportion of training data to set aside as validation set for
        early stopping. Must be between 0 and 1.
        Only used if early_stopping is True.

        .. versionadded:: 0.20

    n_iter_no_change : int, default=5
        Number of iterations with no improvement to wait before early
        stopping.

        .. versionadded:: 0.20

    shuffle : bool, default=True
        Whether or not the training data should be shuffled after each epoch.

    verbose : int, default=0
        The verbosity level.

    loss : str, default="hinge"
        The loss function to be used:

        - ``"hinge"``: equivalent to PA-I in the reference paper.
        - ``"squared_hinge"``: equivalent to PA-II in the reference paper.

    n_jobs : int or None, default=None
        The number of CPUs to use to do the OVA (One Versus All, for
        multi-class problems) computation.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    random_state : int, RandomState instance, default=None
        Used to shuffle the training data, when ``shuffle`` is set to
        ``True``. Pass an int for reproducible output across multiple
        function calls.
        See :term:`Glossary <random_state>`.

    warm_start : bool, default=False
        When set to True, reuse the solution of the previous call to fit as
        initialization, otherwise, just erase the previous solution.
        See :term:`the Glossary <warm_start>`.

        Repeatedly calling fit or partial_fit when warm_start is True can
        result in a different solution than when calling fit a single time
        because of the way the data is shuffled.

    class_weight : dict, {class_label: weight} or "balanced" or None, \
            default=None
        Preset for the class_weight fit parameter.

        Weights associated with classes. If not given, all classes
        are supposed to have weight one.

        The "balanced" mode uses the values of y to automatically adjust
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``.

        .. versionadded:: 0.17
           parameter *class_weight* to automatically weight samples.

    average : bool or int, default=False
        When set to True, computes the averaged SGD weights and stores the
        result in the ``coef_`` attribute. If set to an int greater than 1,
        averaging will begin once the total number of samples seen reaches
        average. So average=10 will begin averaging after seeing 10 samples.

        .. versionadded:: 0.19
           parameter *average* to use weights averaging in SGD.

    Attributes
    ----------
    coef_ : ndarray of shape (1, n_features) if n_classes == 2 else \
            (n_classes, n_features)
        Weights assigned to the features.

    intercept_ : ndarray of shape (1,) if n_classes == 2 else (n_classes,)
        Constants in decision function.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_iter_ : int
        The actual number of iterations to reach the stopping criterion.
        For multiclass fits, it is the maximum over every binary fit.

    classes_ : ndarray of shape (n_classes,)
        The unique classes labels.

    t_ : int
        Number of weight updates performed during training.
        Same as ``(n_iter_ * n_samples)``.

    loss_function_ : callable
        Loss function used by the algorithm.

    See Also
    --------
    SGDClassifier : Incrementally trained logistic regression.
    Perceptron : Linear perceptron classifier.

    References
    ----------
    Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)

    Examples
    --------
    >>> from sklearn.linear_model import PassiveAggressiveClassifier
    >>> from sklearn.datasets import make_classification
    >>> X, y = make_classification(n_features=4, random_state=0)
    >>> clf = PassiveAggressiveClassifier(max_iter=1000, random_state=0,
    ... tol=1e-3)
    >>> clf.fit(X, y)
    PassiveAggressiveClassifier(random_state=0)
    >>> print(clf.coef_)
    [[0.26642044 0.45070924 0.67251877 0.64185414]]
    >>> print(clf.intercept_)
    [1.84127814]
    >>> print(clf.predict([[0, 0, 0, 0]]))
    [1]
    """

    def __init__(
        self,
        *,
        C=1.0,
        fit_intercept=True,
        max_iter=1000,
        tol=1e-3,
        early_stopping=False,
        validation_fraction=0.1,
        n_iter_no_change=5,
        shuffle=True,
        verbose=0,
        loss="hinge",
        n_jobs=None,
        random_state=None,
        warm_start=False,
        class_weight=None,
        average=False,
    ):
        super().__init__(
            # Hard-coded here; not exposed through this constructor.
            penalty=None,
            eta0=1.0,
            # Forwarded unchanged from the constructor arguments.
            fit_intercept=fit_intercept,
            max_iter=max_iter,
            tol=tol,
            shuffle=shuffle,
            verbose=verbose,
            random_state=random_state,
            early_stopping=early_stopping,
            validation_fraction=validation_fraction,
            n_iter_no_change=n_iter_no_change,
            warm_start=warm_start,
            class_weight=class_weight,
            average=average,
            n_jobs=n_jobs,
        )

        # Assigned after the base-class constructor has run, as in the
        # original ordering.
        self.C = C
        self.loss = loss

    def partial_fit(self, X, y, classes=None):
        """Fit linear model with Passive Aggressive algorithm.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Subset of the training data.

        y : array-like of shape (n_samples,)
            Subset of the target values.

        classes : ndarray of shape (n_classes,)
            Classes across all calls to partial_fit.
            Can be obtained via `np.unique(y_all)`, where y_all is the
            target vector of the entire dataset.
            This argument is required for the first call to partial_fit
            and can be omitted in the subsequent calls.
            Note that y doesn't need to contain all labels in `classes`.

        Returns
        -------
        self : object
            Fitted estimator.

        Raises
        ------
        ValueError
            If ``class_weight`` is set to ``"balanced"``.
        """
        self._validate_params(for_partial_fit=True)

        if self.class_weight == "balanced":
            message = (
                "class_weight 'balanced' is not supported for "
                "partial_fit. For 'balanced' weights, use "
                "`sklearn.utils.compute_class_weight` with "
                "`class_weight='balanced'`. In place of y you "
                "can use a large enough subset of the full "
                "training set target to properly estimate the "
                "class frequency distributions. Pass the "
                "resulting weights as the class_weight "
                "parameter."
            )
            raise ValueError(message)

        # The loss handed down is always "hinge"; the PA-I / PA-II choice is
        # carried by the learning-rate name.
        pa_rule = _pa_variant(self.loss, "hinge")
        return self._partial_fit(
            X,
            y,
            alpha=1.0,
            C=self.C,
            loss="hinge",
            learning_rate=pa_rule,
            max_iter=1,
            classes=classes,
            sample_weight=None,
            coef_init=None,
            intercept_init=None,
        )

    def fit(self, X, y, coef_init=None, intercept_init=None):
        """Fit linear model with Passive Aggressive algorithm.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,)
            Target values.

        coef_init : ndarray of shape (n_classes, n_features)
            The initial coefficients to warm-start the optimization.

        intercept_init : ndarray of shape (n_classes,)
            The initial intercept to warm-start the optimization.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        self._validate_params()

        # The loss handed down is always "hinge"; the PA-I / PA-II choice is
        # carried by the learning-rate name.
        pa_rule = _pa_variant(self.loss, "hinge")
        return self._fit(
            X,
            y,
            alpha=1.0,
            C=self.C,
            loss="hinge",
            learning_rate=pa_rule,
            coef_init=coef_init,
            intercept_init=intercept_init,
        )


class PassiveAggressiveRegressor(BaseSGDRegressor):
    """Passive Aggressive Regressor.

    Read more in the :ref:`User Guide <passive_aggressive>`.

    Parameters
    ----------

    C : float, default=1.0
        Maximum step size (regularization).

    fit_intercept : bool, default=True
        Whether the intercept should be estimated or not. If False, the
        data is assumed to be already centered.

    max_iter : int, default=1000
        The maximum number of passes over the training data (aka epochs).
        It only impacts the behavior in the ``fit`` method, and not the
        :meth:`partial_fit` method.

        .. versionadded:: 0.19

    tol : float or None, default=1e-3
        The stopping criterion. If it is not None, the iterations will stop
        when (loss > previous_loss - tol).

        .. versionadded:: 0.19

    early_stopping : bool, default=False
        Whether to use early stopping to terminate training when the
        validation score is not improving. If set to True, a fraction of
        the training data is automatically set aside as a validation set,
        and training terminates when the validation score is not improving
        by at least tol for n_iter_no_change consecutive epochs.

        .. versionadded:: 0.20

    validation_fraction : float, default=0.1
        The proportion of training data to set aside as validation set for
        early stopping. Must be between 0 and 1.
        Only used if early_stopping is True.

        .. versionadded:: 0.20

    n_iter_no_change : int, default=5
        Number of iterations with no improvement to wait before early
        stopping.

        .. versionadded:: 0.20

    shuffle : bool, default=True
        Whether or not the training data should be shuffled after each epoch.

    verbose : int, default=0
        The verbosity level.

    loss : str, default="epsilon_insensitive"
        The loss function to be used:

        - ``"epsilon_insensitive"``: equivalent to PA-I in the reference
          paper.
        - ``"squared_epsilon_insensitive"``: equivalent to PA-II in the
          reference paper.

    epsilon : float, default=0.1
        If the difference between the current prediction and the correct label
        is below this threshold, the model is not updated.

    random_state : int, RandomState instance, default=None
        Used to shuffle the training data, when ``shuffle`` is set to
        ``True``. Pass an int for reproducible output across multiple
        function calls.
        See :term:`Glossary <random_state>`.

    warm_start : bool, default=False
        When set to True, reuse the solution of the previous call to fit as
        initialization, otherwise, just erase the previous solution.
        See :term:`the Glossary <warm_start>`.

        Repeatedly calling fit or partial_fit when warm_start is True can
        result in a different solution than when calling fit a single time
        because of the way the data is shuffled.

    average : bool or int, default=False
        When set to True, computes the averaged SGD weights and stores the
        result in the ``coef_`` attribute. If set to an int greater than 1,
        averaging will begin once the total number of samples seen reaches
        average. So average=10 will begin averaging after seeing 10 samples.

        .. versionadded:: 0.19
           parameter *average* to use weights averaging in SGD.

    Attributes
    ----------
    coef_ : ndarray of shape (n_features,)
        Weights assigned to the features.

    intercept_ : ndarray of shape (1,)
        Constants in decision function.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_iter_ : int
        The actual number of iterations to reach the stopping criterion.

    t_ : int
        Number of weight updates performed during training.
        Same as ``(n_iter_ * n_samples)``.

    See Also
    --------
    SGDRegressor : Linear model fitted by minimizing a regularized
        empirical loss with SGD.

    References
    ----------
    Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006).

    Examples
    --------
    >>> from sklearn.linear_model import PassiveAggressiveRegressor
    >>> from sklearn.datasets import make_regression

    >>> X, y = make_regression(n_features=4, random_state=0)
    >>> regr = PassiveAggressiveRegressor(max_iter=100, random_state=0,
    ... tol=1e-3)
    >>> regr.fit(X, y)
    PassiveAggressiveRegressor(max_iter=100, random_state=0)
    >>> print(regr.coef_)
    [20.48736655 34.18818427 67.59122734 87.94731329]
    >>> print(regr.intercept_)
    [-0.02306214]
    >>> print(regr.predict([[0, 0, 0, 0]]))
    [-0.02306214]
    """

    def __init__(
        self,
        *,
        C=1.0,
        fit_intercept=True,
        max_iter=1000,
        tol=1e-3,
        early_stopping=False,
        validation_fraction=0.1,
        n_iter_no_change=5,
        shuffle=True,
        verbose=0,
        loss="epsilon_insensitive",
        epsilon=DEFAULT_EPSILON,
        random_state=None,
        warm_start=False,
        average=False,
    ):
        super().__init__(
            # Hard-coded here; not exposed through this constructor.
            penalty=None,
            l1_ratio=0,
            eta0=1.0,
            # Forwarded unchanged from the constructor arguments.
            epsilon=epsilon,
            fit_intercept=fit_intercept,
            max_iter=max_iter,
            tol=tol,
            shuffle=shuffle,
            verbose=verbose,
            random_state=random_state,
            early_stopping=early_stopping,
            validation_fraction=validation_fraction,
            n_iter_no_change=n_iter_no_change,
            warm_start=warm_start,
            average=average,
        )
        # Assigned after the base-class constructor has run, as in the
        # original ordering.
        self.C = C
        self.loss = loss

    def partial_fit(self, X, y):
        """Fit linear model with Passive Aggressive algorithm.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Subset of training data.

        y : ndarray of shape (n_samples,)
            Subset of target values.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        self._validate_params(for_partial_fit=True)

        # The loss handed down is always "epsilon_insensitive"; the PA-I /
        # PA-II choice is carried by the learning-rate name.
        pa_rule = _pa_variant(self.loss, "epsilon_insensitive")
        return self._partial_fit(
            X,
            y,
            alpha=1.0,
            C=self.C,
            loss="epsilon_insensitive",
            learning_rate=pa_rule,
            max_iter=1,
            sample_weight=None,
            coef_init=None,
            intercept_init=None,
        )

    def fit(self, X, y, coef_init=None, intercept_init=None):
        """Fit linear model with Passive Aggressive algorithm.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.

        y : ndarray of shape (n_samples,)
            Target values.

        coef_init : ndarray of shape (n_features,)
            The initial coefficients to warm-start the optimization.

        intercept_init : ndarray of shape (1,)
            The initial intercept to warm-start the optimization.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        self._validate_params()

        # The loss handed down is always "epsilon_insensitive"; the PA-I /
        # PA-II choice is carried by the learning-rate name.
        pa_rule = _pa_variant(self.loss, "epsilon_insensitive")
        return self._fit(
            X,
            y,
            alpha=1.0,
            C=self.C,
            loss="epsilon_insensitive",
            learning_rate=pa_rule,
            coef_init=coef_init,
            intercept_init=intercept_init,
        )