1# -*- coding: utf-8 -*-
2"""
3Authors:    Josef Perktold, Skipper Seabold, Denis A. Engemann
4"""
5from statsmodels.compat.python import lrange
6
7import numpy as np
8
9from statsmodels.graphics.plottools import rainbow
10import statsmodels.graphics.utils as utils
11
12
def interaction_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
                     xlabel=None, ylabel=None, colors=None, markers=None,
                     linestyles=None, legendloc='best', legendtitle=None,
                     **kwargs):
    """
    Interaction plot for factor level statistics.

    Note. If categorical (string) factors are supplied for `x`, the levels
    are internally recoded to integers and used as tick labels. This ensures
    matplotlib compatibility. Uses a DataFrame to calculate an `aggregate`
    statistic for each level of the factor or group given by `trace`.

    Parameters
    ----------
    x : array_like
        The `x` factor levels constitute the x-axis. If a `pandas.Series` is
        given its name will be used in `xlabel` if `xlabel` is None. String
        levels (including a plain list of strings) are recoded to integers.
    trace : array_like
        The `trace` factor levels will be drawn as lines in the plot.
        If `trace` is a `pandas.Series` its name will be used as the
        `legendtitle` if `legendtitle` is None.
    response : array_like
        The response or dependent variable. If a `pandas.Series` is given
        its name will be used in `ylabel` if `ylabel` is None.
    func : function or str
        Anything accepted by `pandas.DataFrame.aggregate`. This is applied to
        the response variable grouped by the trace levels. Its ``__name__``
        (or ``str(func)`` when it has none, e.g. ``'mean'``) is used in the
        y-axis label.
    ax : axes, optional
        Matplotlib axes instance
    plottype : str {'line', 'scatter', 'both'}, optional
        The type of plot to return. Can also be given as 'l', 's', or 'b'.
    xlabel : str, optional
        Label to use for `x`. Default is 'X'. If `x` is a named
        `pandas.Series` the series name is used.
    ylabel : str, optional
        Label to use for `response`; the y-axis label reads
        '<func name> of <ylabel>'. Default is 'response', or the series name
        if `response` is a named `pandas.Series`.
    colors : list, optional
        If given, must have length == number of levels in trace.
    markers : list, optional
        If given, must have length == number of levels in trace
    linestyles : list, optional
        If given, must have length == number of levels in trace.
    legendloc : {None, str, int}
        Location passed to the legend command.
    legendtitle : {None, str}
        Title of the legend. Default is 'Trace', or the series name if
        `trace` is a named `pandas.Series`.
    **kwargs
        These will be passed to the plot command used either plot or scatter.
        If you want to control the overall plotting options, use kwargs.

    Returns
    -------
    Figure
        The figure given by `ax.figure` or a new instance.

    Raises
    ------
    ValueError
        If `plottype` is not understood, or if `colors`, `markers` or
        `linestyles` does not have exactly one entry per trace level.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(12345)
    >>> weight = np.random.randint(1,4,size=60)
    >>> duration = np.random.randint(1,3,size=60)
    >>> days = np.log(np.random.randint(1,30, size=60))
    >>> fig = interaction_plot(weight, duration, days,
    ...             colors=['red','blue'], markers=['D','^'], ms=10)
    >>> import matplotlib.pyplot as plt
    >>> plt.show()

    .. plot::

       import numpy as np
       from statsmodels.graphics.factorplots import interaction_plot
       np.random.seed(12345)
       weight = np.random.randint(1,4,size=60)
       duration = np.random.randint(1,3,size=60)
       days = np.log(np.random.randint(1,30, size=60))
       fig = interaction_plot(weight, duration, days,
                   colors=['red','blue'], markers=['D','^'], ms=10)
       import matplotlib.pyplot as plt
       #plt.show()
    """

    from pandas import DataFrame, Series

    # Validate before creating a figure so a bad argument leaves no
    # half-built plot behind.
    if plottype not in ('both', 'b', 'line', 'l', 'scatter', 's'):
        raise ValueError("Plot type %s not understood" % plottype)

    fig, ax = utils.create_mpl_ax(ax)

    # `func` may be a string such as 'mean' (accepted by aggregate), which
    # has no __name__ attribute.
    func_name = getattr(func, '__name__', str(func))
    response_name = ylabel or _name_or_default(response, 'response')
    ylabel = '%s of %s' % (func_name, response_name)
    xlabel = xlabel or _name_or_default(x, 'X')
    legendtitle = legendtitle or _name_or_default(trace, 'Trace')

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    x_values = x_levels = None
    # Positional access via an ndarray: `x[0]` on a Series is a label lookup
    # and fails when the index does not contain 0.
    x_arr = np.asarray(x)
    if x_arr.size and isinstance(x_arr.flat[0], str):
        x_levels = list(np.unique(x_arr))
        x_values = lrange(len(x_levels))
        # Keep a Series as-is (so _recode can restore its name/index);
        # otherwise hand over the array, since a plain list has no dtype.
        to_recode = x if isinstance(x, Series) else x_arr
        x = _recode(to_recode, dict(zip(x_levels, x_values)))

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()

    # check plot args
    n_trace = len(plot_data['trace'].unique())

    linestyles = ['-'] * n_trace if linestyles is None else linestyles
    markers = ['.'] * n_trace if markers is None else markers
    colors = rainbow(n_trace) if colors is None else colors

    if len(linestyles) != n_trace:
        raise ValueError("Must be a linestyle for each trace level")
    if len(markers) != n_trace:
        raise ValueError("Must be a marker for each trace level")
    if len(colors) != n_trace:
        raise ValueError("Must be a color for each trace level")

    # A scalar key (not a length-1 list) yields plain group keys; groups are
    # iterated in sorted trace order, matching the style lists by position.
    for i, (_, group) in enumerate(plot_data.groupby('trace')):
        # trace label
        label = str(group['trace'].values[0])
        if plottype in ('both', 'b'):
            ax.plot(group['x'], group['response'], color=colors[i],
                    marker=markers[i], label=label,
                    linestyle=linestyles[i], **kwargs)
        elif plottype in ('line', 'l'):
            ax.plot(group['x'], group['response'], color=colors[i],
                    label=label, linestyle=linestyles[i], **kwargs)
        else:  # 'scatter' / 's' -- plottype was validated above
            ax.scatter(group['x'], group['response'], color=colors[i],
                       label=label, marker=markers[i], **kwargs)

    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(.1)

    if x_levels is not None:
        # restore the original string levels as tick labels
        ax.set_xticks(x_values)
        ax.set_xticklabels(x_levels)
    return fig


def _name_or_default(obj, default):
    """Return ``obj.name``, or `default` if it is missing or None."""
    name = getattr(obj, 'name', None)
    return default if name is None else name
159
160
def _recode(x, levels):
    """ Recode categorical data to int factor.

    Parameters
    ----------
    x : array_like
        Categorically coded data: a `pandas.Series`, a numpy array or a
        sequence of strings. After conversion to a numpy array it must have
        str or object dtype.
    levels : dict
        Mapping of labels to integer codings. Its keys must be exactly the
        distinct values occurring in `x`.

    Returns
    -------
    out : numpy.ndarray or pandas.Series
        Integer codes, one per element of `x`. If `x` is a `pandas.Series`
        with a truthy name, a Series with the same name and index is
        returned; otherwise a numpy array.

    Raises
    ------
    ValueError
        If `x` does not have str/object dtype, `levels` is not a dict, or
        the keys of `levels` differ from the distinct values in `x`.
    """
    from pandas import Series
    name = None
    index = None

    if isinstance(x, Series):
        name = x.name
        index = x.index
    # np.asarray also accepts plain lists (which have no .dtype) and turns
    # pandas extension arrays (e.g. the string dtype) into numpy arrays.
    x = np.asarray(x)

    if x.dtype.type not in (np.str_, np.object_):
        raise ValueError('This is not a categorial factor.'
                         ' Array of str type required.')

    if not isinstance(levels, dict):
        raise ValueError('This is not a valid value for levels.'
                         ' Dict required.')

    # array_equal, unlike an elementwise ==, handles a different number of
    # levels without failing to broadcast.
    if not np.array_equal(np.unique(x), np.unique(list(levels.keys()))):
        raise ValueError('The levels do not match the array values.')

    # every element is assigned: the keys cover all values of x (checked above)
    out = np.empty(x.shape[0], dtype=int)
    for level, coding in levels.items():
        out[x == level] = coding

    if name:
        out = Series(out, name=name, index=index)

    return out
205