"""
==========================
Plotting Validation Curves
==========================

This example shows how the training score and the cross-validation score of
an SVM classifier change as the kernel parameter gamma is varied. When gamma
is very low, both scores are low: the model is underfitting. Intermediate
values of gamma give high values for both scores, meaning the classifier
performs well. When gamma is too high, the model overfits: the training score
stays high while the validation score drops.

"""

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import validation_curve

# Restrict the digits dataset to a binary problem: digit "1" versus digit "2".
X, y = load_digits(return_X_y=True)
keep = np.isin(y, [1, 2])
X = X[keep]
y = y[keep]

# Five gamma values spaced evenly on a log scale, from 1e-6 up to 1e-1.
gamma_values = np.logspace(-6, -1, 5)

# Each returned array has one row per gamma value and one column per
# cross-validation split; the fits are spread over two parallel jobs.
train_scores, test_scores = validation_curve(
    SVC(),
    X,
    y,
    param_name="gamma",
    param_range=gamma_values,
    scoring="accuracy",
    n_jobs=2,
)

# Draw on the current Axes (the same target the pyplot wrappers would use).
ax = plt.gca()
ax.set_title("Validation Curve with SVM")
ax.set_xlabel(r"$\gamma$")
ax.set_ylabel("Score")
# Accuracy never exceeds 1.0; the extra 0.1 keeps the curves off the top edge.
ax.set_ylim(0.0, 1.1)

line_width = 2
curves = (
    (train_scores, "Training score", "darkorange"),
    (test_scores, "Cross-validation score", "navy"),
)
for scores, label, colour in curves:
    # Summarise each gamma value across its cross-validation splits (axis 1).
    mean = scores.mean(axis=1)
    spread = scores.std(axis=1)
    ax.semilogx(gamma_values, mean, label=label, color=colour, lw=line_width)
    # Shaded band covering mean +/- one standard deviation.
    ax.fill_between(
        gamma_values,
        mean - spread,
        mean + spread,
        alpha=0.2,
        color=colour,
        lw=line_width,
    )

ax.legend(loc="best")
plt.show()