# -*- coding: utf-8 -*-
"""
Authors:    Josef Perktold, Skipper Seabold, Denis A. Engemann
"""
from statsmodels.compat.python import lrange

import numpy as np

from statsmodels.graphics.plottools import rainbow
import statsmodels.graphics.utils as utils


def _name_or_default(explicit, obj, default):
    """Return `explicit` if truthy, else `obj.name` if truthy, else `default`.

    A plain ``getattr(obj, 'name', default)`` is not enough here: an unnamed
    `pandas.Series` *has* a ``name`` attribute whose value is None, which
    would otherwise end up in the axis labels.
    """
    return explicit or getattr(obj, 'name', None) or default


def interaction_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
                     xlabel=None, ylabel=None, colors=None, markers=None,
                     linestyles=None, legendloc='best', legendtitle=None,
                     **kwargs):
    """
    Interaction plot for factor level statistics.

    Note. If categorical factors are supplied levels will be internally
    recoded to integers. This ensures matplotlib compatibility. Uses
    a DataFrame to calculate an `aggregate` statistic for each level of the
    factor or group given by `trace`.

    Parameters
    ----------
    x : array_like
        The `x` factor levels constitute the x-axis. If a `pandas.Series` is
        given its name will be used in `xlabel` if `xlabel` is None.
    trace : array_like
        The `trace` factor levels will be drawn as lines in the plot.
        If `trace` is a `pandas.Series` its name will be used as the
        `legendtitle` if `legendtitle` is None.
    response : array_like
        The response or dependent variable. If a `pandas.Series` is given
        its name will be used in `ylabel` if `ylabel` is None.
    func : function
        Anything accepted by `pandas.DataFrame.aggregate`. This is applied to
        the response variable grouped by the trace levels. If `func` has no
        ``__name__`` (e.g. it is a string such as ``'mean'``), ``str(func)``
        is used in the y-axis label.
    ax : axes, optional
        Matplotlib axes instance
    plottype : str {'line', 'scatter', 'both'}, optional
        The type of plot to return. Can be 'l', 's', or 'b'
    xlabel : str, optional
        Label to use for `x`. Default is 'X'. If `x` is a `pandas.Series` it
        will use the series names.
    ylabel : str, optional
        Name to use for `response` in the y-axis label. The label is always
        rendered as '<func name> of <ylabel>'. Default is
        'func of response'. If `response` is a `pandas.Series` it will use
        the series names.
    colors : list, optional
        If given, must have length == number of levels in trace.
    markers : list, optional
        If given, must have length == number of levels in trace
    linestyles : list, optional
        If given, must have length == number of levels in trace.
    legendloc : {None, str, int}
        Location passed to the legend command.
    legendtitle : {None, str}
        Title of the legend.
    **kwargs
        These will be passed to the plot command used either plot or scatter.
        If you want to control the overall plotting options, use kwargs.

    Returns
    -------
    Figure
        The figure given by `ax.figure` or a new instance.

    Raises
    ------
    ValueError
        If `colors`, `markers` or `linestyles` do not have one entry per
        trace level, or if `plottype` is not one of the accepted values.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(12345)
    >>> weight = np.random.randint(1,4,size=60)
    >>> duration = np.random.randint(1,3,size=60)
    >>> days = np.log(np.random.randint(1,30, size=60))
    >>> fig = interaction_plot(weight, duration, days,
    ...             colors=['red','blue'], markers=['D','^'], ms=10)
    >>> import matplotlib.pyplot as plt
    >>> plt.show()

    .. plot::

       import numpy as np
       from statsmodels.graphics.factorplots import interaction_plot
       np.random.seed(12345)
       weight = np.random.randint(1,4,size=60)
       duration = np.random.randint(1,3,size=60)
       days = np.log(np.random.randint(1,30, size=60))
       fig = interaction_plot(weight, duration, days,
                   colors=['red','blue'], markers=['D','^'], ms=10)
       import matplotlib.pyplot as plt
       #plt.show()
    """

    from pandas import DataFrame
    fig, ax = utils.create_mpl_ax(ax)

    response_name = _name_or_default(ylabel, response, 'response')
    # `func` may be a string or another object without __name__; anything
    # DataFrame.aggregate accepts must not crash the label construction.
    func_name = getattr(func, '__name__', str(func))
    ylabel = '%s of %s' % (func_name, response_name)
    xlabel = _name_or_default(xlabel, x, 'X')
    legendtitle = _name_or_default(legendtitle, trace, 'Trace')

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    x_values = x_levels = None
    # Inspect the first element positionally: `x[0]` on a pandas.Series is a
    # label lookup and fails for any index that does not contain 0.
    x_array = np.asarray(x)
    if x_array.size and isinstance(x_array.flat[0], str):
        x_levels = list(np.unique(x_array))
        x_values = lrange(len(x_levels))
        x = _recode(x, dict(zip(x_levels, x_values)))

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()

    # check plot args
    n_trace = len(plot_data['trace'].unique())

    linestyles = ['-'] * n_trace if linestyles is None else linestyles
    markers = ['.'] * n_trace if markers is None else markers
    colors = rainbow(n_trace) if colors is None else colors

    if len(linestyles) != n_trace:
        raise ValueError("Must be a linestyle for each trace level")
    if len(markers) != n_trace:
        raise ValueError("Must be a marker for each trace level")
    if len(colors) != n_trace:
        raise ValueError("Must be a color for each trace level")

    if plottype in ('both', 'b'):
        kind = 'b'
    elif plottype in ('line', 'l'):
        kind = 'l'
    elif plottype in ('scatter', 's'):
        kind = 's'
    else:
        raise ValueError("Plot type %s not understood" % plottype)

    # Group by a scalar key (not a one-element list) so the group keys stay
    # scalars and newer pandas versions do not warn about tuple keys.
    for i, (_, group) in enumerate(plot_data.groupby('trace')):
        # trace label
        label = str(group['trace'].values[0])
        if kind == 's':
            ax.scatter(group['x'], group['response'], color=colors[i],
                       label=label, marker=markers[i], **kwargs)
        elif kind == 'l':
            ax.plot(group['x'], group['response'], color=colors[i],
                    label=label, linestyle=linestyles[i], **kwargs)
        else:
            ax.plot(group['x'], group['response'], color=colors[i],
                    marker=markers[i], label=label,
                    linestyle=linestyles[i], **kwargs)

    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(.1)

    # Only set when `x` was string-coded and recoded to integers above.
    if x_levels is not None:
        ax.set_xticks(x_values)
        ax.set_xticklabels(x_levels)
    return fig


def _recode(x, levels):
    """ Recode categorical data to int factor.

    Parameters
    ----------
    x : array_like
        array like object of categorically coded (string) data. A
        `pandas.Series`, a `numpy.ndarray` or any sequence convertible with
        `numpy.asarray` is accepted.
    levels : dict
        mapping of labels to integer-codings

    Returns
    -------
    out : instance numpy.ndarray
        Integer codes, one per element of `x`. If `x` is a `pandas.Series`
        with a (truthy) name, a `pandas.Series` carrying the same name and
        index is returned instead.

    Raises
    ------
    ValueError
        If `x` is not of string/object dtype, if `levels` is not a dict, or
        if the keys of `levels` do not match the unique values of `x`.
    """
    from pandas import Series
    name = None
    index = None

    if isinstance(x, Series):
        name = x.name
        index = x.index
    # Normalise lists, Series and extension arrays to a plain ndarray so the
    # dtype check and boolean masking below behave uniformly.
    x = np.asarray(x)

    if x.dtype.type not in [np.str_, np.object_]:
        raise ValueError('This is not a categorial factor.'
                         ' Array of str type required.')

    if not isinstance(levels, dict):
        raise ValueError('This is not a valid value for levels.'
                         ' Dict required.')

    # array_equal also copes with a differing number of levels, where an
    # elementwise `==` would fail instead of reporting the mismatch.
    if not np.array_equal(np.unique(x), np.unique(list(levels))):
        raise ValueError('The levels do not match the array values.')

    out = np.empty(x.shape[0], dtype=int)
    for level, coding in levels.items():
        out[x == level] = coding

    if name:
        out = Series(out, name=name, index=index)

    return out