"""
===================================
Simple 1D Kernel Density Estimation
===================================
This example uses the :class:`~sklearn.neighbors.KernelDensity` class to
demonstrate the principles of Kernel Density Estimation in one dimension.

The first plot shows one of the problems with using histograms to visualize
the density of points in 1D. Intuitively, a histogram can be thought of as a
scheme in which a unit "block" is stacked above each point on a regular grid.
As the top two panels show, however, the choice of gridding for these blocks
can lead to wildly divergent ideas about the underlying shape of the density
distribution. If we instead center each block on the point it represents, we
get the estimate shown in the bottom left panel. This is a kernel density
estimation with a "top hat" kernel. This idea can be generalized to other
kernel shapes: the bottom-right panel of the first figure shows a Gaussian
kernel density estimate over the same distribution.

Scikit-learn implements efficient kernel density estimation using either
a Ball Tree or KD Tree structure, through the
:class:`~sklearn.neighbors.KernelDensity` estimator. The available kernels
are shown in the second figure of this example.

The third figure compares kernel density estimates for a distribution of 100
samples in 1 dimension. Though this example uses 1D distributions, kernel
density estimation is easily and efficiently extensible to higher dimensions
as well.

"""

# Author: Jake Vanderplas <jakevdp@cs.washington.edu>
#
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from scipy.stats import norm
from sklearn.neighbors import KernelDensity
from sklearn.utils.fixes import parse_version

# `normed` is being deprecated in favor of `density` in histograms
if parse_version(matplotlib.__version__) >= parse_version("2.1"):
    density_param = {"density": True}
else:
    density_param = {"normed": True}

# ----------------------------------------------------------------------
# Plot the progression of histograms to kernels
np.random.seed(1)
N = 20
# Bimodal 1D sample: 30% of points from N(0, 1), 70% from N(5, 1),
# reshaped to the (n_samples, n_features) layout KernelDensity expects.
X = np.concatenate(
    (np.random.normal(0, 1, int(0.3 * N)), np.random.normal(5, 1, int(0.7 * N)))
)[:, np.newaxis]
X_plot = np.linspace(-5, 10, 1000)[:, np.newaxis]
bins = np.linspace(-5, 10, 10)

fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)
fig.subplots_adjust(hspace=0.05, wspace=0.05)

# histogram 1
ax[0, 0].hist(X[:, 0], bins=bins, fc="#AAAAFF", **density_param)
ax[0, 0].text(-3.5, 0.31, "Histogram")

# histogram 2: same data, bin edges shifted by 0.75 — illustrates how
# sensitive a histogram's apparent shape is to the choice of gridding
ax[0, 1].hist(X[:, 0], bins=bins + 0.75, fc="#AAAAFF", **density_param)
ax[0, 1].text(-3.5, 0.31, "Histogram, bins shifted")

# tophat KDE
kde = KernelDensity(kernel="tophat", bandwidth=0.75).fit(X)
log_dens = kde.score_samples(X_plot)  # score_samples returns log-density
ax[1, 0].fill(X_plot[:, 0], np.exp(log_dens), fc="#AAAAFF")
ax[1, 0].text(-3.5, 0.31, "Tophat Kernel Density")

# Gaussian KDE
kde = KernelDensity(kernel="gaussian", bandwidth=0.75).fit(X)
log_dens = kde.score_samples(X_plot)
ax[1, 1].fill(X_plot[:, 0], np.exp(log_dens), fc="#AAAAFF")
ax[1, 1].text(-3.5, 0.31, "Gaussian Kernel Density")

# Mark the raw sample positions just below the x-axis on every panel
for axi in ax.ravel():
    axi.plot(X[:, 0], np.full(X.shape[0], -0.01), "+k")
    axi.set_xlim(-4, 9)
    axi.set_ylim(-0.02, 0.34)

for axi in ax[:, 0]:
    axi.set_ylabel("Normalized Density")

for axi in ax[1, :]:
    axi.set_xlabel("x")

# ----------------------------------------------------------------------
# Plot all available kernels
X_plot = np.linspace(-6, 6, 1000)[:, None]
X_src = np.zeros((1, 1))  # a single point at the origin, so each plot shows the bare kernel

fig, ax = plt.subplots(2, 3, sharex=True, sharey=True)
fig.subplots_adjust(left=0.05, right=0.95, hspace=0.05, wspace=0.05)


def format_func(x, loc):
    """Tick formatter labelling positions in units of the bandwidth ``h``.

    Parameters
    ----------
    x : float
        Tick value (an integer multiple of the tick spacing).
    loc : int
        Tick position index (required by ``plt.FuncFormatter``; unused).

    Returns
    -------
    str
        "0", "h", "-h", or "<n>h" for other integer multiples.
    """
    if x == 0:
        return "0"
    elif x == 1:
        return "h"
    elif x == -1:
        return "-h"
    else:
        return "%ih" % x


for i, kernel in enumerate(
    ["gaussian", "tophat", "epanechnikov", "exponential", "linear", "cosine"]
):
    axi = ax.ravel()[i]
    log_dens = KernelDensity(kernel=kernel).fit(X_src).score_samples(X_plot)
    axi.fill(X_plot[:, 0], np.exp(log_dens), "-k", fc="#AAAAFF")
    axi.text(-2.6, 0.95, kernel)

    axi.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
    axi.xaxis.set_major_locator(plt.MultipleLocator(1))
    axi.yaxis.set_major_locator(plt.NullLocator())

    axi.set_ylim(0, 1.05)
    axi.set_xlim(-2.9, 2.9)

ax[0, 1].set_title("Available Kernels")

# ----------------------------------------------------------------------
# Plot a 1D density example
N = 100
np.random.seed(1)
X = np.concatenate(
    (np.random.normal(0, 1, int(0.3 * N)), np.random.normal(5, 1, int(0.7 * N)))
)[:, np.newaxis]

X_plot = np.linspace(-5, 10, 1000)[:, np.newaxis]

# Analytic density of the generating mixture, for comparison with the estimates
true_dens = 0.3 * norm(0, 1).pdf(X_plot[:, 0]) + 0.7 * norm(5, 1).pdf(X_plot[:, 0])

fig, ax = plt.subplots()
ax.fill(X_plot[:, 0], true_dens, fc="black", alpha=0.2, label="input distribution")
colors = ["navy", "cornflowerblue", "darkorange"]
kernels = ["gaussian", "tophat", "epanechnikov"]
lw = 2

for color, kernel in zip(colors, kernels):
    kde = KernelDensity(kernel=kernel, bandwidth=0.5).fit(X)
    log_dens = kde.score_samples(X_plot)
    ax.plot(
        X_plot[:, 0],
        np.exp(log_dens),
        color=color,
        lw=lw,
        linestyle="-",
        label="kernel = '{0}'".format(kernel),
    )

ax.text(6, 0.38, "N={0} points".format(N))

ax.legend(loc="upper left")
# Jitter the sample markers vertically so overlapping points stay visible
ax.plot(X[:, 0], -0.005 - 0.01 * np.random.random(X.shape[0]), "+k")

ax.set_xlim(-4, 9)
ax.set_ylim(-0.02, 0.4)
plt.show()