"""
==========================
Plotting Validation Curves
==========================

In this plot you can see the training scores and validation scores of an SVM
for different values of its RBF kernel parameter gamma, fitted on a binary
subset of the digits dataset (1 vs 2). For very low values of gamma, both the
training score and the validation score are low. This is called underfitting.
Medium values of gamma result in high values for both scores, i.e. the
classifier is performing fairly well. If gamma is too high, the classifier
overfits: the training score is good but the validation score is poor.

"""
15
16import matplotlib.pyplot as plt
17import numpy as np
18
19from sklearn.datasets import load_digits
20from sklearn.svm import SVC
21from sklearn.model_selection import validation_curve
22
# Reduce the digits data to a binary problem: keep only samples labelled 1 or 2.
X, y = load_digits(return_X_y=True)
is_one_or_two = np.isin(y, [1, 2])
X = X[is_one_or_two]
y = y[is_one_or_two]

# Five gamma candidates spaced evenly on a log scale between 1e-6 and 1e-1.
param_range = np.logspace(-6, -1, 5)

# For every gamma, cross-validate a default SVC and record accuracy on both
# the training folds and the held-out folds (one row per gamma value).
train_scores, test_scores = validation_curve(
    SVC(),
    X,
    y,
    param_name="gamma",
    param_range=param_range,
    scoring="accuracy",
    n_jobs=2,
)

# Collapse the per-fold axis so each gamma gets a single mean and spread.
train_scores_mean = train_scores.mean(axis=1)
train_scores_std = train_scores.std(axis=1)
test_scores_mean = test_scores.mean(axis=1)
test_scores_std = test_scores.std(axis=1)

line_width = 2
fig, ax = plt.subplots()
ax.set_title("Validation Curve with SVM")
ax.set_xlabel(r"$\gamma$")
ax.set_ylabel("Score")
ax.set_ylim(0.0, 1.1)

# Each curve is a mean line on a log-x axis plus a +/- one-std shaded band.
# The training curve is drawn first, then the CV curve, so the layering
# matches a line-then-band sequence for each series.
curve_specs = (
    ("Training score", "darkorange", train_scores_mean, train_scores_std),
    ("Cross-validation score", "navy", test_scores_mean, test_scores_std),
)
for curve_label, curve_color, curve_mean, curve_std in curve_specs:
    ax.semilogx(
        param_range, curve_mean, label=curve_label, color=curve_color, lw=line_width
    )
    ax.fill_between(
        param_range,
        curve_mean - curve_std,
        curve_mean + curve_std,
        alpha=0.2,
        color=curve_color,
        lw=line_width,
    )

ax.legend(loc="best")
plt.show()
71