1import numpy as np
2
3from . import utils
4
5
def dot_plot(points, intervals=None, lines=None, sections=None,
             styles=None, marker_props=None, line_props=None,
             split_names=None, section_order=None, line_order=None,
             stacked=False, styles_order=None, striped=False,
             horizontal=True, show_names="both",
             fmt_left_name=None, fmt_right_name=None,
             show_section_titles=None, ax=None):
    """
    Dot plotting (also known as forest plot and blobbogram).

    Produce a dotplot similar in style to those in Cleveland's
    "Visualizing Data" book ([1]_).  These are also known as "forest plots".

    Parameters
    ----------
    points : array_like
        The quantitative values to be plotted as markers.
    intervals : array_like
        The intervals to be plotted around the points.  The elements
        of `intervals` are either scalars or sequences of length 2.  A
        scalar indicates the half width of a symmetric interval.  A
        sequence of length 2 contains the left and right half-widths
        (respectively) of a nonsymmetric interval.  If None, no
        intervals are drawn.
    lines : array_like
        A grouping variable indicating which points/intervals are
        drawn on a common line.  If None, each point/interval appears
        on its own line.
    sections : array_like
        A grouping variable indicating which lines are grouped into
        sections.  If None, everything is drawn in a single section.
    styles : array_like
        A grouping label defining the plotting style of the markers
        and intervals.  If None, all points share a single style.
    marker_props : dict
        A dictionary mapping style codes (the values in `styles`) to
        dictionaries defining key/value pairs to be passed as keyword
        arguments to `plot` when plotting markers.  Useful keyword
        arguments are "color", "marker", and "ms" (marker size).
        Style codes without an entry, and entries without "color",
        "marker" or "ms", receive defaults.  The caller's dictionaries
        are copied and never modified.
    line_props : dict
        A dictionary mapping style codes (the values in `styles`) to
        dictionaries defining key/value pairs to be passed as keyword
        arguments to `plot` when plotting interval lines.  Useful
        keyword arguments are "color", "linestyle", "solid_capstyle",
        and "linewidth".  Style codes without an entry, and entries
        without "color" or "linewidth", receive defaults.  The
        caller's dictionaries are copied and never modified.
    split_names : str
        If not None, this is used to split the values of `lines` (which
        must then be strings) into substrings that are drawn in the
        left and right margins, respectively.  If None, the values of
        `lines` are drawn in the left margin.
    section_order : array_like
        The section labels in the order in which they appear in the
        dotplot.  Points whose section is not listed are not drawn.
    line_order : array_like
        The line labels in the order in which they appear in the
        dotplot.  Points whose line is not listed are not drawn.
    stacked : bool
        If True, when multiple points or intervals are drawn on the
        same line, they are offset from each other.
    styles_order : array_like
        If stacked=True, this is the order in which the point styles
        on a given line are drawn from top to bottom (if horizontal
        is True) or from left to right (if horizontal is False).  If
        None (default), the order is lexical.
    striped : bool
        If True, every other line is enclosed in a shaded box.
    horizontal : bool
        If True (default), the lines are drawn horizontally, otherwise
        they are drawn vertically.
    show_names : str
        Determines whether labels (names) are shown in the left and/or
        right margins (bottom/top margins if `horizontal` is False).
        If 'both', labels are drawn in both margins; if 'left', labels
        are drawn only in the left (or bottom) margin; if 'right',
        labels are drawn only in the right (or top) margin.  The value
        is compared case-insensitively.  Right-margin labels only exist
        when `split_names` splits a line label into at least two parts.
    fmt_left_name : callable
        The left/bottom margin names are passed through this function
        before drawing on the plot.
    fmt_right_name : callable
        The right/top margin names are passed through this function
        before drawing on the plot.  It is not called for lines that
        have no right-margin name.
    show_section_titles : bool or None
        If None, section titles are drawn only if there is more than
        one section.  If False/True, section titles are never/always
        drawn, respectively.
    ax : matplotlib.axes
        The axes on which the dotplot is drawn.  If None, a new axes
        is created.

    Returns
    -------
    fig : Figure
        The figure given by `ax.figure` or a new instance.

    Notes
    -----
    `points`, `intervals`, `lines`, `sections`, `styles` must all have
    the same length whenever present.

    If no points remain to be drawn (empty `points`, or everything is
    filtered out by `section_order`/`line_order`), the axes are set up
    but no lines or markers are drawn.

    References
    ----------
    .. [1] Cleveland, William S. (1993). "Visualizing Data". Hobart Press.
    .. [2] Jacoby, William G. (2006) "The Dot Plot: A Graphical Display
       for Labeled Quantitative Values." The Political Methodologist
       14(1): 6-14.

    Examples
    --------
    This is a simple dotplot with one point per line:

    >>> dot_plot(points=point_values)

    This dotplot has labels on the lines (if elements in
    `label_values` are repeated, the corresponding points appear on
    the same line):

    >>> dot_plot(points=point_values, lines=label_values)
    """

    import matplotlib.transforms as transforms

    fig, ax = utils.create_mpl_ax(ax)

    def _asarray_or_none(x):
        return None if x is None else np.asarray(x)

    # Convert to numpy arrays if that is not what we are given.
    points = np.asarray(points)
    intervals = _asarray_or_none(intervals)
    lines = _asarray_or_none(lines)
    sections = _asarray_or_none(sections)
    styles = _asarray_or_none(styles)

    # Total number of points
    npoint = len(points)

    # Default grouping variables: one line per point, a single section
    # and a single style.
    if lines is None:
        lines = np.arange(npoint)
    if sections is None:
        sections = np.zeros(npoint)
    if styles is None:
        styles = np.zeros(npoint)

    # The vertical space (in inches) for a section title
    section_title_space = 0.5

    # The number of sections
    if section_order is not None:
        nsect = len(set(section_order))
    else:
        nsect = len(set(sections))

    # Whether section titles are drawn, and how many take up space.
    if show_section_titles is False:
        draw_section_titles = False
    elif show_section_titles is True:
        draw_section_titles = True
    else:
        draw_section_titles = nsect > 1
    nsect_title = nsect if draw_section_titles else 0

    # Add a bit of room so that points that fall at the axis limits
    # are not cut in half.
    ax.set_xmargin(0.02)
    ax.set_ymargin(0.02)

    # Section and line labels, in drawing order.
    if section_order is None:
        section_keys = sorted(set(sections))
    else:
        section_keys = section_order
    if line_order is None:
        line_keys = sorted(set(lines))
    else:
        line_keys = line_order

    # Membership sets for the optional order filters (built once so the
    # per-point test below is O(1)).
    section_filter = None if section_order is None else set(section_order)
    line_filter = None if line_order is None else set(line_order)

    # A map from (section, line) codes to the indices of the points
    # drawn on that line.
    lines_map = {}
    for i in range(npoint):
        if section_filter is not None and sections[i] not in section_filter:
            continue
        if line_filter is not None and lines[i] not in line_filter:
            continue
        lines_map.setdefault((sections[i], lines[i]), []).append(i)

    # Get the size of the axes on the parent figure in inches
    bbox = ax.get_window_extent().transformed(
        fig.dpi_scale_trans.inverted())
    awidth, aheight = bbox.width, bbox.height

    # The number of lines in the plot.  Clamped to 1 so that an empty
    # plot does not divide by zero below (no line is drawn in that case).
    nrows = max(len(lines_map), 1)

    # The positions of the lowest and highest guideline in axes
    # coordinates (for horizontal dotplots), or the leftmost and
    # rightmost guidelines (for vertical dotplots).
    bottom, top = 0, 1

    if horizontal:
        # x coordinate is data, y coordinate is axes
        trans = transforms.blended_transform_factory(ax.transData,
                                                     ax.transAxes)
    else:
        # x coordinate is axes, y coordinate is data
        trans = transforms.blended_transform_factory(ax.transAxes,
                                                     ax.transData)

    # Space used for a section title, in axes coordinates
    title_space_axes = section_title_space / aheight

    # Space between lines
    if horizontal:
        dpos = (top - bottom - nsect_title * title_space_axes) / nrows
    else:
        dpos = (top - bottom) / nrows

    # Determine the spacing for stacked points.  A list copy is taken so
    # that tuples/arrays are accepted for `styles_order`.
    if styles_order is not None:
        style_codes = list(styles_order)
    else:
        style_codes = sorted(set(styles))
    # Order is top to bottom for horizontal plots, so need to flip.
    if horizontal:
        style_codes = style_codes[::-1]
    # nval is the maximum number of points on one line.
    nval = len(style_codes)
    stackd = dpos / (2.5 * (nval - 1)) if nval > 1 else 0.

    # Map from style code to its integer position (first occurrence).
    style_codes_map = {}
    for j, sc in enumerate(style_codes):
        style_codes_map.setdefault(sc, j)

    # Marker styles: copy the caller's dicts, then fill in defaults.
    colors = ["r", "g", "b", "y", "k", "purple", "orange"]
    marker_props = {k: dict(v) for k, v in
                    ({} if marker_props is None else marker_props).items()}
    for j, sc in enumerate(style_codes):
        props = marker_props.setdefault(sc, {})
        props.setdefault("color", colors[j % len(colors)])
        props.setdefault("marker", "o")
        props.setdefault("ms", 10 if stackd == 0 else 6)

    # Interval line styles: copy the caller's dicts, then fill in defaults.
    line_props = {k: dict(v) for k, v in
                  ({} if line_props is None else line_props).items()}
    for sc in style_codes:
        props = line_props.setdefault(sc, {})
        props.setdefault("color", "grey")
        props.setdefault("linewidth", 2 if stackd > 0 else 8)

    # Which margins receive names.
    names_mode = show_names.lower()
    show_left = names_mode in ("left", "both")
    show_right = names_mode in ("right", "both")

    if horizontal:
        # The vertical position of the first line.
        pos = top - dpos/2 if nsect == 1 else top
    else:
        # The horizontal position of the first line.
        pos = bottom + dpos/2

    # Styles that have already been given a legend label
    labeled = set()

    # Positions of the guideline ticks
    ticks = []

    # Loop through the sections
    for isect, k0 in enumerate(section_keys):

        # Draw a section title
        if draw_section_titles:

            if horizontal:

                y0 = pos + dpos/2 if isect == 0 else pos

                ax.fill_between((0, 1), (y0, y0),
                                (pos - 0.7*title_space_axes,
                                 pos - 0.7*title_space_axes),
                                color='darkgrey',
                                transform=ax.transAxes,
                                zorder=1)

                txt = ax.text(0.5, pos - 0.35*title_space_axes, k0,
                              horizontalalignment='center',
                              verticalalignment='center',
                              transform=ax.transAxes)
                txt.set_fontweight("bold")
                pos -= title_space_axes

            else:

                # Number of lines in this section, to size the banner.
                m = sum(1 for key in lines_map if key[0] == k0)

                ax.fill_between((pos - dpos/2 + 0.01,
                                 pos + (m - 1)*dpos + dpos/2 - 0.01),
                                (1.01, 1.01), (1.06, 1.06),
                                color='darkgrey',
                                transform=ax.transAxes,
                                zorder=1, clip_on=False)

                txt = ax.text(pos + (m - 1)*dpos/2, 1.02, k0,
                              horizontalalignment='center',
                              verticalalignment='bottom',
                              transform=ax.transAxes)
                txt.set_fontweight("bold")

        jrow = 0
        for k1 in line_keys:

            # No data to plot
            if (k0, k1) not in lines_map:
                continue

            # Draw the guideline
            if horizontal:
                ax.axhline(pos, color='grey')
            else:
                ax.axvline(pos, color='grey')

            # Set up the labels
            left_label, right_label = k1, None
            if split_names is not None:
                parts = k1.split(split_names)
                if len(parts) >= 2:
                    left_label, right_label = parts[0], parts[1]

            if fmt_left_name is not None:
                left_label = fmt_left_name(left_label)

            # Only format right labels that actually exist; the formatter
            # must never receive None.
            if fmt_right_name is not None and right_label is not None:
                right_label = fmt_right_name(right_label)

            # Draw the stripe
            if striped and jrow % 2 == 0:
                if horizontal:
                    ax.fill_between((0, 1), (pos - dpos/2, pos - dpos/2),
                                    (pos + dpos/2, pos + dpos/2),
                                    color='lightgrey',
                                    transform=ax.transAxes,
                                    zorder=0)
                else:
                    ax.fill_between((pos - dpos/2, pos + dpos/2),
                                    (0, 0), (1, 1),
                                    color='lightgrey',
                                    transform=ax.transAxes,
                                    zorder=0)

            jrow += 1

            # Draw the left (or bottom) margin label
            if show_left:
                if horizontal:
                    ax.text(-0.1/awidth, pos, left_label,
                            horizontalalignment="right",
                            verticalalignment='center',
                            transform=ax.transAxes,
                            family='monospace')
                else:
                    ax.text(pos, -0.1/aheight, left_label,
                            horizontalalignment="center",
                            verticalalignment='top',
                            transform=ax.transAxes,
                            family='monospace')

            # Draw the right (or top) margin label
            if show_right and right_label is not None:
                if horizontal:
                    ax.text(1 + 0.1/awidth, pos, right_label,
                            horizontalalignment="left",
                            verticalalignment='center',
                            transform=ax.transAxes,
                            family='monospace')
                else:
                    ax.text(pos, 1 + 0.1/aheight, right_label,
                            horizontalalignment="center",
                            verticalalignment='bottom',
                            transform=ax.transAxes,
                            family='monospace')

            # Save the guideline position so that we can place the
            # tick marks
            ticks.append(pos)

            # Loop over the points in one line
            for jp in lines_map[(k0, k1)]:

                sl = styles[jp]

                # Offset within the line when styles are stacked
                yo = 0
                if stacked:
                    yo = -dpos/5 + style_codes_map[sl]*stackd

                pt = points[jp]

                # Plot the interval
                if intervals is not None:

                    half = intervals[jp]
                    if np.isscalar(half):
                        # Symmetric interval
                        lcb, ucb = pt - half, pt + half
                    else:
                        # Nonsymmetric interval (left, right half-widths)
                        lcb, ucb = pt - half[0], pt + half[1]

                    if horizontal:
                        ax.plot([lcb, ucb], [pos + yo, pos + yo], '-',
                                transform=trans,
                                **line_props[sl])
                    else:
                        ax.plot([pos + yo, pos + yo], [lcb, ucb], '-',
                                transform=trans,
                                **line_props[sl])

                # Plot the point; only the first point of each style
                # carries a legend label.
                sll = sl if sl not in labeled else None
                labeled.add(sl)
                if horizontal:
                    ax.plot([pt], [pos + yo], ls='None',
                            transform=trans, label=sll,
                            **marker_props[sl])
                else:
                    ax.plot([pos + yo], [pt], ls='None',
                            transform=trans, label=sll,
                            **marker_props[sl])

            if horizontal:
                pos -= dpos
            else:
                pos += dpos

    # Set up the axis
    if horizontal:
        ax.xaxis.set_ticks_position("bottom")
        ax.yaxis.set_ticks_position("none")
        ax.set_yticklabels([])
        ax.spines['left'].set_color('none')
        ax.spines['right'].set_color('none')
        ax.spines['top'].set_color('none')
        ax.spines['bottom'].set_position(('axes', -0.1/aheight))
        ax.set_ylim(0, 1)
        ax.yaxis.set_ticks(ticks)
        ax.autoscale_view(scaley=False, tight=True)
    else:
        ax.yaxis.set_ticks_position("left")
        ax.xaxis.set_ticks_position("none")
        ax.set_xticklabels([])
        ax.spines['bottom'].set_color('none')
        ax.spines['right'].set_color('none')
        ax.spines['top'].set_color('none')
        ax.spines['left'].set_position(('axes', -0.1/awidth))
        ax.set_xlim(0, 1)
        ax.xaxis.set_ticks(ticks)
        ax.autoscale_view(scalex=False, tight=True)

    return fig
490