1""" 2================================================================ 3Permutation Importance vs Random Forest Feature Importance (MDI) 4================================================================ 5 6In this example, we will compare the impurity-based feature importance of 7:class:`~sklearn.ensemble.RandomForestClassifier` with the 8permutation importance on the titanic dataset using 9:func:`~sklearn.inspection.permutation_importance`. We will show that the 10impurity-based feature importance can inflate the importance of numerical 11features. 12 13Furthermore, the impurity-based feature importance of random forests suffers 14from being computed on statistics derived from the training dataset: the 15importances can be high even for features that are not predictive of the target 16variable, as long as the model has the capacity to use them to overfit. 17 18This example shows how to use Permutation Importances as an alternative that 19can mitigate those limitations. 20 21.. topic:: References: 22 23 [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 24 2001. https://doi.org/10.1023/A:1010933404324 25 26""" 27 28import matplotlib.pyplot as plt 29import numpy as np 30 31from sklearn.datasets import fetch_openml 32from sklearn.ensemble import RandomForestClassifier 33from sklearn.impute import SimpleImputer 34from sklearn.inspection import permutation_importance 35from sklearn.compose import ColumnTransformer 36from sklearn.model_selection import train_test_split 37from sklearn.pipeline import Pipeline 38from sklearn.preprocessing import OneHotEncoder 39 40 41# %% 42# Data Loading and Feature Engineering 43# ------------------------------------ 44# Let's use pandas to load a copy of the titanic dataset. The following shows 45# how to apply separate preprocessing on numerical and categorical features. 46# 47# We further include two random variables that are not correlated in any way 48# with the target variable (``survived``): 49# 50# - ``random_num`` is a high cardinality numerical variable (as many unique 51# values as records). 52# - ``random_cat`` is a low cardinality categorical variable (3 possible 53# values). 54X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True) 55rng = np.random.RandomState(seed=42) 56X["random_cat"] = rng.randint(3, size=X.shape[0]) 57X["random_num"] = rng.randn(X.shape[0]) 58 59categorical_columns = ["pclass", "sex", "embarked", "random_cat"] 60numerical_columns = ["age", "sibsp", "parch", "fare", "random_num"] 61 62X = X[categorical_columns + numerical_columns] 63 64X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42) 65 66categorical_encoder = OneHotEncoder(handle_unknown="ignore") 67numerical_pipe = Pipeline([("imputer", SimpleImputer(strategy="mean"))]) 68 69preprocessing = ColumnTransformer( 70 [ 71 ("cat", categorical_encoder, categorical_columns), 72 ("num", numerical_pipe, numerical_columns), 73 ] 74) 75 76rf = Pipeline( 77 [ 78 ("preprocess", preprocessing), 79 ("classifier", RandomForestClassifier(random_state=42)), 80 ] 81) 82rf.fit(X_train, y_train) 83 84# %% 85# Accuracy of the Model 86# --------------------- 87# Prior to inspecting the feature importances, it is important to check that 88# the model predictive performance is high enough. Indeed there would be little 89# interest of inspecting the important features of a non-predictive model. 

# %%
# Accuracy of the Model
# ---------------------
# Prior to inspecting the feature importances, it is important to check that
# the model's predictive performance is high enough. Indeed, there would be
# little interest in inspecting the important features of a non-predictive
# model.
#
# Here one can observe that the train accuracy is very high (the forest model
# has enough capacity to completely memorize the training set) but it can still
# generalize well enough to the test set thanks to the built-in bagging of
# random forests.
#
# It might be possible to trade some accuracy on the training set for a
# slightly better accuracy on the test set by limiting the capacity of the
# trees (for instance by setting ``min_samples_leaf=5`` or
# ``min_samples_leaf=10``) so as to limit overfitting while not introducing too
# much underfitting.
#
# However, let's keep our high capacity random forest model for now so as to
# illustrate some pitfalls with feature importance on variables with many
# unique values.
print("RF train accuracy: %0.3f" % rf.score(X_train, y_train))
print("RF test accuracy: %0.3f" % rf.score(X_test, y_test))


# %%
# Tree's Feature Importance from Mean Decrease in Impurity (MDI)
# --------------------------------------------------------------
# The impurity-based feature importance ranks the numerical features as the
# most important features. As a result, the non-predictive ``random_num``
# variable is ranked the most important!
#
# This problem stems from two limitations of impurity-based feature
# importances:
#
# - impurity-based importances are biased towards high cardinality features;
# - impurity-based importances are computed on training set statistics and
#   therefore do not reflect the ability of a feature to be useful for making
#   predictions that generalize to the test set (when the model has enough
#   capacity).
ohe = rf.named_steps["preprocess"].named_transformers_["cat"]
feature_names = ohe.get_feature_names_out(categorical_columns)
feature_names = np.r_[feature_names, numerical_columns]

tree_feature_importances = rf.named_steps["classifier"].feature_importances_
sorted_idx = tree_feature_importances.argsort()

y_ticks = np.arange(0, len(feature_names))
fig, ax = plt.subplots()
ax.barh(y_ticks, tree_feature_importances[sorted_idx])
ax.set_yticks(y_ticks)
ax.set_yticklabels(feature_names[sorted_idx])
ax.set_title("Random Forest Feature Importances (MDI)")
fig.tight_layout()
plt.show()


# %%
# As an alternative, the permutation importances of ``rf`` are computed on a
# held out test set. This shows that the low cardinality categorical feature,
# ``sex``, is the most important feature.
#
# Also note that both random features have very low importances (close to 0) as
# expected.
result = permutation_importance(
    rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)
sorted_idx = result.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(
    result.importances[sorted_idx].T, vert=False, labels=X_test.columns[sorted_idx]
)
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()
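
# %%
# A small addition to the original example: the raw numbers behind the box
# plot above. ``result.importances_mean`` and ``result.importances_std``
# summarize the ``n_repeats`` permutation rounds for each feature.
for i in sorted_idx[::-1]:
    print(
        f"{X_test.columns[i]:<12} "
        f"{result.importances_mean[i]:.3f} "
        f"+/- {result.importances_std[i]:.3f}"
    )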

# %%
# It is also possible to compute the permutation importances on the training
# set. This reveals that ``random_num`` gets a significantly higher importance
# ranking than when computed on the test set. The difference between those two
# plots confirms that the RF model has enough capacity to use that random
# numerical feature to overfit. You can further confirm this by re-running
# this example with a constrained RF with ``min_samples_leaf=10``, as sketched
# after the plot below.
result = permutation_importance(
    rf, X_train, y_train, n_repeats=10, random_state=42, n_jobs=2
)
sorted_idx = result.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(
    result.importances[sorted_idx].T, vert=False, labels=X_train.columns[sorted_idx]
)
ax.set_title("Permutation Importances (train set)")
fig.tight_layout()
plt.show()
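
# %%
# A minimal sketch of that constrained variant (an addition to the original
# example, reusing the preprocessing defined above): with
# ``min_samples_leaf=10`` the forest can no longer memorize the training set,
# so the train-set permutation importance of ``random_num`` should drop
# towards its (near zero) test-set value. The names ``rf_constrained`` and
# ``result_constrained`` are ours, introduced only for this illustration.
rf_constrained = Pipeline(
    [
        ("preprocess", preprocessing),
        ("classifier", RandomForestClassifier(min_samples_leaf=10, random_state=42)),
    ]
)
rf_constrained.fit(X_train, y_train)
print("Constrained RF train accuracy: %0.3f" % rf_constrained.score(X_train, y_train))
print("Constrained RF test accuracy: %0.3f" % rf_constrained.score(X_test, y_test))

result_constrained = permutation_importance(
    rf_constrained, X_train, y_train, n_repeats=10, random_state=42, n_jobs=2
)
print(
    "random_num permutation importance (train set): %0.3f"
    % result_constrained.importances_mean[X_train.columns.get_loc("random_num")]
)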