1"""Utility functions, mostly for internal use."""
2import os
3import re
4import inspect
5import warnings
6import colorsys
7from urllib.request import urlopen, urlretrieve
8
9import numpy as np
10from scipy import stats
11import pandas as pd
12import matplotlib as mpl
13import matplotlib.colors as mplcol
14import matplotlib.pyplot as plt
15from matplotlib.cbook import normalize_kwargs
16
17
# Names exported by ``from <this module> import *``; every other helper in
# this module is intended for internal use.
__all__ = [
    "desaturate",
    "saturate",
    "set_hls_values",
    "despine",
    "get_dataset_names",
    "get_data_home",
    "load_dataset",
]
20
21
def sort_df(df, *args, **kwargs):
    """Sort a DataFrame, bridging the pandas sorting API change in 0.17.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    df : :class:`pandas.DataFrame`
        Data to sort.
    args, kwargs
        Passed through to ``DataFrame.sort_values`` (or to
        ``DataFrame.sort`` on pandas versions that lack ``sort_values``).

    Returns
    -------
    sorted_df : :class:`pandas.DataFrame`
        The sorted data.

    """
    msg = "This function is deprecated and will be removed in a future version"
    # FutureWarning matches the other deprecated helpers in this module and is
    # the category intended for deprecations aimed at end users.
    warnings.warn(msg, FutureWarning)
    try:
        sort_method = df.sort_values
    except AttributeError:
        # pandas < 0.17 only provides ``DataFrame.sort``
        sort_method = df.sort
    # Called outside the try block so that errors raised while sorting are
    # not mistaken for a missing method.
    return sort_method(*args, **kwargs)
30
31
def ci_to_errsize(cis, heights):
    """Convert intervals to error arguments relative to plot heights.

    Parameters
    ----------
    cis : 2 x n sequence
        sequence of confidence interval limits; the first row holds the
        lower limits and the second row the upper limits
    heights : n sequence
        sequence of plot heights (a single height is broadcast)

    Returns
    -------
    errsize : 2 x n array
        sequence of error size relative to height values in correct
        format as argument for plt.bar

    """
    cis = np.atleast_2d(cis).reshape(2, -1)
    heights = np.atleast_1d(heights)
    low, high = cis
    # Row 0: distance from each height down to its lower limit.
    # Row 1: distance from each height up to its upper limit.
    # Whole-array arithmetic keeps the documented 2 x n shape even for n == 0.
    errsize = np.asarray([heights - low, high - heights])
    return errsize
60
61
def pmf_hist(a, bins=10):
    """Compute plt.bar arguments for a histogram normalized like a pmf.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    a: array-like
        values to bin
    bins: int
        number of bins

    Returns
    -------
    x: array
        left edge of each bar
    h: array
        bar heights; the fraction of values falling in each bin
    w: float
        width of the first bin, used for every bar

    """
    warnings.warn(
        "This function is deprecated and will be removed in a future version",
        FutureWarning,
    )
    counts, edges = np.histogram(a, bins)
    fractions = counts / counts.sum()
    bar_width = edges[1] - edges[0]
    left_edges = edges[:-1]
    return left_edges, fractions, bar_width
90
91
def desaturate(color, prop):
    """Scale down the saturation of a color.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    prop : float
        factor in [0, 1] that the saturation channel is multiplied by

    Returns
    -------
    new_color : rgb tuple
        the less saturated color, as an RGB tuple

    Raises
    ------
    ValueError
        If ``prop`` is not within [0, 1] (this includes NaN).

    """
    if not (0 <= prop <= 1):
        raise ValueError("prop must be between 0 and 1")

    # Work in HLS space so only the saturation channel is touched
    hue, lightness, saturation = colorsys.rgb_to_hls(
        *mplcol.colorConverter.to_rgb(color)
    )
    return colorsys.hls_to_rgb(hue, lightness, saturation * prop)
125
126
def saturate(color):
    """Push a color to full saturation while keeping its hue.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name

    Returns
    -------
    new_color : rgb tuple
        fully saturated version of ``color`` in RGB tuple representation

    """
    # Only the saturation channel is overridden; hue and lightness come
    # from ``color`` unchanged.
    return set_hls_values(color, h=None, l=None, s=1)
142
143
def set_hls_values(color, h=None, l=None, s=None):  # noqa
    """Replace any of the hue, lightness, or saturation channels of a color.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    h, l, s : floats between 0 and 1, or None
        replacement value for each channel in HLS space; ``None`` keeps the
        channel from ``color``

    Returns
    -------
    new_color : rgb tuple
        new color code in RGB tuple representation

    """
    current = colorsys.rgb_to_hls(*mplcol.colorConverter.to_rgb(color))
    requested = (h, l, s)
    merged = [
        old if new is None else new
        for old, new in zip(current, requested)
    ]
    return colorsys.hls_to_rgb(*merged)
169
170
def axlabel(xlabel, ylabel, **kwargs):
    """Set the x and y labels of the current axes.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    xlabel, ylabel : str
        label text for the x and y axis
    kwargs
        forwarded to both ``Axes.set_xlabel`` and ``Axes.set_ylabel``

    """
    warnings.warn(
        "This function is deprecated and will be removed in a future version",
        FutureWarning,
    )
    current_ax = plt.gca()
    current_ax.set_xlabel(xlabel, **kwargs)
    current_ax.set_ylabel(ylabel, **kwargs)
182
183
def remove_na(vector):
    """Drop null entries from a data vector.

    Parameters
    ----------
    vector : vector object
        Must implement boolean masking with [] subscript syntax.

    Returns
    -------
    clean_vector : same type as ``vector``
        ``vector`` without its null values. May be a copy or a view.

    """
    keep = pd.notnull(vector)
    return vector[keep]
199
200
def get_color_cycle():
    """Return the colors in the current matplotlib property cycle.

    Returns
    -------
    colors : list
        Colors from ``rcParams["axes.prop_cycle"]``, or ``[".15"]`` (dark
        gray) when that cycle does not define a ``color`` property.

    """
    prop_cycle = mpl.rcParams['axes.prop_cycle']
    if 'color' not in prop_cycle.keys:
        return [".15"]
    return prop_cycle.by_key()['color']
216
217
def despine(fig=None, ax=None, top=True, right=True, left=False,
            bottom=False, offset=None, trim=False):
    """Remove the top and right spines from plot(s).

    Parameters
    ----------
    fig : matplotlib figure, optional
        Figure to despine all axes of, defaults to the current figure.
    ax : matplotlib axes, optional
        Specific axes object to despine. Ignored if fig is provided.
    top, right, left, bottom : boolean, optional
        If True, remove that spine.
    offset : int or dict, optional
        Absolute distance, in points, spines should be moved away
        from the axes (negative values move spines inward). A single value
        applies to all spines; a dict can be used to set offset values per
        side.
    trim : bool, optional
        If True, limit spines to the smallest and largest major tick
        on each non-despined axis. An axis with no major tick inside its
        view limits is left untrimmed.

    Returns
    -------
    None

    """
    # Get references to the axes we want
    if fig is not None:
        axes = fig.axes
    elif ax is not None:
        axes = [ax]
    else:
        axes = plt.gcf().axes

    # Whether each side's spine should be removed. The mapping is explicit
    # rather than reading the arguments back out of ``locals()`` by name.
    remove_side = {"top": top, "right": right, "left": left, "bottom": bottom}

    for ax_i in axes:
        for side, remove in remove_side.items():
            # Toggle the spine objects
            is_visible = not remove
            ax_i.spines[side].set_visible(is_visible)
            if offset is not None and is_visible:
                try:
                    val = offset.get(side, 0)
                except AttributeError:
                    # A scalar offset applies to every side
                    val = offset
                ax_i.spines[side].set_position(('outward', val))

        # When only the left/bottom spine is removed, move the ticks to the
        # opposite (still visible) side
        if left and not right:
            _move_ticks_to_opposite_side(ax_i.yaxis, "right")
        if bottom and not top:
            _move_ticks_to_opposite_side(ax_i.xaxis, "top")

        if trim:
            # clip off the parts of the spines that extend past major ticks
            _trim_spines_to_ticks(ax_i, "x")
            _trim_spines_to_ticks(ax_i, "y")


def _move_ticks_to_opposite_side(axis, position):
    """Move ``axis`` ticks to ``position``, keeping whether they were shown."""
    # Record whether the major/minor tick lines were shown before moving,
    # then apply that visibility to the tick lines on the new side.
    maj_on = any(t.tick1line.get_visible() for t in axis.majorTicks)
    min_on = any(t.tick1line.get_visible() for t in axis.minorTicks)
    axis.set_ticks_position(position)
    for t in axis.majorTicks:
        t.tick2line.set_visible(maj_on)
    for t in axis.minorTicks:
        t.tick2line.set_visible(min_on)


def _trim_spines_to_ticks(ax, which):
    """Clip the spines along ``which`` ("x" or "y") to its in-view ticks."""
    if which == "x":
        ticks = np.asarray(ax.get_xticks())
        lims = ax.get_xlim()
        spines = ("bottom", "top")
        set_ticks = ax.set_xticks
    else:
        ticks = np.asarray(ax.get_yticks())
        lims = ax.get_ylim()
        spines = ("left", "right")
        set_ticks = ax.set_yticks

    if not ticks.size:
        return

    not_below = ticks[ticks >= min(lims)]
    not_above = ticks[ticks <= max(lims)]
    if not (not_below.size and not_above.size):
        # No tick lies within the view limits, so there is nothing to trim
        # to; indexing the empty selection would raise IndexError.
        return

    firsttick = not_below[0]
    lasttick = not_above[-1]
    for side in spines:
        ax.spines[side].set_bounds(firsttick, lasttick)
    newticks = ticks[(ticks <= lasttick) & (ticks >= firsttick)]
    set_ticks(newticks)
318
319
320def _kde_support(data, bw, gridsize, cut, clip):
321    """Establish support for a kernel density estimate."""
322    support_min = max(data.min() - bw * cut, clip[0])
323    support_max = min(data.max() + bw * cut, clip[1])
324    support = np.linspace(support_min, support_max, gridsize)
325
326    return support
327
328
def percentiles(a, pcts, axis=None):
    """Like scoreatpercentile but can take and return array of percentiles.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    a : array
        data (must support ``.ravel()`` when ``axis`` is None)
    pcts : sequence of percentile values
        percentile or percentiles to find score at
    axis : int or None
        if not None, computes scores over this axis

    Returns
    -------
    scores: array
        array of scores at requested percentiles
        first dimension is length of object passed to ``pcts``; a scalar
        ``pcts`` gives a squeezed result

    """
    warnings.warn(
        "This function is deprecated and will be removed in a future version",
        FutureWarning,
    )

    try:
        n_requested = len(pcts)
    except TypeError:
        # A single percentile was given: wrap it and squeeze the result below
        pcts = [pcts]
        n_requested = 0

    if axis is None:
        scores = [stats.scoreatpercentile(a.ravel(), p) for p in pcts]
    else:
        scores = [
            np.apply_along_axis(stats.scoreatpercentile, axis, a, p)
            for p in pcts
        ]
    scores = np.asarray(scores)
    return scores if n_requested else scores.squeeze()
369
370
def ci(a, which=95, axis=None):
    """Return the central ``which`` percent range of ``a`` along ``axis``."""
    half_width = which / 2
    bounds = (50 - half_width, 50 + half_width)
    return np.percentile(a, bounds, axis)
375
376
def sig_stars(p):
    """Map a p value to an R-style significance marker.

    DEPRECATED: will be removed in a future version.

    Returns "***" below 0.001, "**" below 0.01, "*" below 0.05, "." below
    0.1, and "" otherwise.

    """
    warnings.warn(
        "This function is deprecated and will be removed in a future version",
        FutureWarning,
    )

    # Checked from the strictest cutoff outward; first match wins
    thresholds = ((0.001, "***"), (0.01, "**"), (0.05, "*"), (0.1, "."))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ""
395
396
def iqr(a):
    """Return the interquartile range (75th minus 25th percentile) of ``a``.

    DEPRECATED: will be removed in a future version.

    """
    warnings.warn(
        "This function is deprecated and will be removed in a future version",
        FutureWarning,
    )

    values = np.asarray(a)
    lower, upper = (stats.scoreatpercentile(values, q) for q in (25, 75))
    return upper - lower
410
411
def get_dataset_names():
    """Report available example datasets, useful for reporting issues.

    Requires an internet connection.

    Returns
    -------
    names : list of str
        Dataset names (without the ``.csv`` extension) in the order they are
        first linked from the repository page, without duplicates.

    """
    url = "https://github.com/mwaskom/seaborn-data"
    with urlopen(url) as resp:
        html = resp.read()

    # The dot is escaped so only real ``.csv`` links match, and the name
    # must be non-empty.
    pat = r"/mwaskom/seaborn-data/blob/master/(\w+)\.csv"
    # A file may be linked more than once on the page; keep the first
    # occurrence of each name.
    datasets = list(dict.fromkeys(re.findall(pat, html.decode())))
    return datasets
425
426
def get_data_home(data_home=None):
    """Return a path to the cache directory for example datasets.

    This directory is then used by :func:`load_dataset`.

    If the ``data_home`` argument is not specified, it tries to read from the
    ``SEABORN_DATA`` environment variable and defaults to ``~/seaborn-data``.
    The directory, including any missing parents, is created if it does not
    already exist.

    """
    if data_home is None:
        data_home = os.environ.get('SEABORN_DATA',
                                   os.path.join('~', 'seaborn-data'))
    data_home = os.path.expanduser(data_home)
    if not os.path.exists(data_home):
        # exist_ok guards against another process creating the directory
        # between the existence check above and this call.
        os.makedirs(data_home, exist_ok=True)
    return data_home
443
444
def load_dataset(name, cache=True, data_home=None, **kws):
    """Load an example dataset from the online repository (requires internet).

    This function provides quick access to a small number of example datasets
    that are useful for documenting seaborn or generating reproducible examples
    for bug reports. It is not necessary for normal usage.

    Note that some of the datasets have a small amount of preprocessing applied
    to define a proper ordering for categorical variables.

    Use :func:`get_dataset_names` to see a list of available datasets.

    Parameters
    ----------
    name : str
        Name of the dataset (``{name}.csv`` on
        https://github.com/mwaskom/seaborn-data).
    cache : boolean, optional
        If True, try to load from the local cache first, and save to the cache
        if a download is required.
    data_home : string, optional
        The directory in which to cache data; see :func:`get_data_home`.
    kws : keys and values, optional
        Additional keyword arguments are passed through to
        :func:`pandas.read_csv`.

    Returns
    -------
    df : :class:`pandas.DataFrame`
        Tabular data, possibly with some preprocessing applied.

    Raises
    ------
    ValueError
        If ``cache`` is True, the file is not cached yet, and ``name`` is not
        among :func:`get_dataset_names`.

    """
    path = ("https://raw.githubusercontent.com/"
            "mwaskom/seaborn-data/master/{}.csv")
    full_path = path.format(name)

    if cache:
        cache_path = os.path.join(get_data_home(data_home),
                                  os.path.basename(full_path))
        if not os.path.exists(cache_path):
            if name not in get_dataset_names():
                raise ValueError(f"'{name}' is not one of the example datasets.")
            # Download under a temporary name and move it into place only
            # once complete. An interrupted download then never leaves a
            # truncated file that later calls would treat as a valid cache.
            partial_path = cache_path + ".part"
            urlretrieve(full_path, partial_path)
            os.replace(partial_path, cache_path)
        full_path = cache_path

    df = pd.read_csv(full_path, **kws)

    # Drop a trailing row that is entirely null. The length check avoids an
    # IndexError when the frame is empty (e.g. ``nrows=0``).
    if len(df) and df.iloc[-1].isnull().all():
        df = df.iloc[:-1]

    # Set some columns as a categorical type with ordered levels

    if name == "tips":
        df["day"] = pd.Categorical(df["day"], ["Thur", "Fri", "Sat", "Sun"])
        df["sex"] = pd.Categorical(df["sex"], ["Male", "Female"])
        df["time"] = pd.Categorical(df["time"], ["Lunch", "Dinner"])
        df["smoker"] = pd.Categorical(df["smoker"], ["Yes", "No"])

    if name == "flights":
        months = df["month"].str[:3]
        df["month"] = pd.Categorical(months, months.unique())

    if name == "exercise":
        df["time"] = pd.Categorical(df["time"], ["1 min", "15 min", "30 min"])
        df["kind"] = pd.Categorical(df["kind"], ["rest", "walking", "running"])
        df["diet"] = pd.Categorical(df["diet"], ["no fat", "low fat"])

    if name == "titanic":
        df["class"] = pd.Categorical(df["class"], ["First", "Second", "Third"])
        df["deck"] = pd.Categorical(df["deck"], list("ABCDEFG"))

    if name == "penguins":
        df["sex"] = df["sex"].str.title()

    if name == "diamonds":
        df["color"] = pd.Categorical(
            df["color"], ["D", "E", "F", "G", "H", "I", "J"],
        )
        df["clarity"] = pd.Categorical(
            df["clarity"], ["IF", "VVS1", "VVS2", "VS1", "VS2", "SI1", "SI2", "I1"],
        )
        df["cut"] = pd.Categorical(
            df["cut"], ["Ideal", "Premium", "Very Good", "Good", "Fair"],
        )

    return df
531
532
def axis_ticklabels_overlap(labels):
    """Return a boolean for whether the list of ticklabels have overlaps.

    Parameters
    ----------
    labels : list of matplotlib ticklabels

    Returns
    -------
    overlap : boolean
        True if any of the labels overlap; False for an empty list or when
        the label extents cannot be computed.

    """
    if not labels:
        return False
    try:
        extents = [label.get_window_extent() for label in labels]
        # Each box is itself in ``extents``, so a count above 1 means it
        # overlaps at least one other label.
        overlap_counts = [box.count_overlaps(extents) for box in extents]
    except RuntimeError:
        # Issue on macos backend raises an error in the above code
        return False
    return max(overlap_counts) > 1
555
556
def axes_ticklabels_overlap(ax):
    """Report tick-label overlap separately for the x and y axes of ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes

    Returns
    -------
    x_overlap, y_overlap : booleans
        True when the labels on that axis overlap.

    """
    x_overlap = axis_ticklabels_overlap(ax.get_xticklabels())
    y_overlap = axis_ticklabels_overlap(ax.get_yticklabels())
    return x_overlap, y_overlap
572
573
def locator_to_legend_entries(locator, limits, dtype):
    """Return levels and formatted levels for brief numeric legends.

    Parameters
    ----------
    locator : matplotlib tick locator
        Proposes candidate levels between ``limits``.
    limits : pair of numbers
        ``(low, high)`` range; candidate levels outside it are discarded.
    dtype : numpy dtype or type
        Type the candidate levels are cast to.

    Returns
    -------
    raw_levels : list
        Levels within ``limits``.
    formatted_levels : list of str
        Tick-formatter labels for ``raw_levels``.

    """
    candidates = locator.tick_values(*limits).astype(dtype)

    # Locators may propose ticks beyond the limits; keep only those inside
    raw_levels = [
        level for level in candidates
        if level >= limits[0] and level <= limits[1]
    ]

    class _ViewIntervalAxis:
        # Minimal stand-in exposing only the view interval to the formatter
        def get_view_interval(self):
            return limits

    if isinstance(locator, mpl.ticker.LogLocator):
        formatter_cls = mpl.ticker.LogFormatter
    else:
        formatter_cls = mpl.ticker.ScalarFormatter
    formatter = formatter_cls()
    formatter.axis = _ViewIntervalAxis()

    # TODO: The following two lines should be replaced
    # once pinned matplotlib>=3.1.0 with:
    # formatted_levels = formatter.format_ticks(raw_levels)
    formatter.set_locs(raw_levels)
    formatted_levels = [formatter(level) for level in raw_levels]

    return raw_levels, formatted_levels
598
599
def relative_luminance(color):
    """Calculate the relative luminance of a color according to W3C standards

    Parameters
    ----------
    color : matplotlib color or sequence of matplotlib colors
        Hex code, rgb-tuple, or html color name.

    Returns
    -------
    luminance : float(s) between 0 and 1
        A Python float when a single value results, otherwise an array with
        one value per color.

    """
    channels = mpl.colors.colorConverter.to_rgba_array(color)[:, :3]
    # Piecewise sRGB-to-linear conversion used by the W3C definition
    linear = np.where(
        channels <= .03928,
        channels / 12.92,
        ((channels + .055) / 1.055) ** 2.4,
    )
    # Weighted sum of the linear R, G and B channels
    lum = linear.dot([.2126, .7152, .0722])
    return lum.item() if lum.size == 1 else lum
620
621
def to_utf8(obj):
    """Convert an arbitrary object to ``str``.

    ``str`` inputs are returned unchanged. Objects with a ``decode`` method
    (e.g. ``bytes``) are decoded as UTF-8. Anything else is passed through
    ``str()``.

    Parameters
    ----------
    obj : object
        Any Python object

    Returns
    -------
    s : str
        String representation of ``obj``

    """
    if isinstance(obj, str):
        return obj
    try:
        text = obj.decode(encoding="utf-8")
    except AttributeError:  # obj is not bytes-like
        text = str(obj)
    return text
649
650
def _normalize_kwargs(kws, artist):
    """Resolve property aliases in ``kws`` (e.g. ``lw`` -> ``linewidth``).

    Wraps ``matplotlib.cbook.normalize_kwargs``. If passing ``artist``
    raises AttributeError (matplotlib <= 3.2.1), it retries with an
    explicit alias table.
    """
    try:
        return normalize_kwargs(kws, artist)
    except AttributeError:
        alias_map = {
            'color': ['c'],
            'linewidth': ['lw'],
            'linestyle': ['ls'],
            'facecolor': ['fc'],
            'edgecolor': ['ec'],
            'markerfacecolor': ['mfc'],
            'markeredgecolor': ['mec'],
            'markeredgewidth': ['mew'],
            'markersize': ['ms'],
        }
        return normalize_kwargs(kws, alias_map)
669
670
671def _check_argument(param, options, value):
672    """Raise if value for param is not in options."""
673    if value not in options:
674        raise ValueError(
675            f"`{param}` must be one of {options}, but {value} was passed.`"
676        )
677
678
679def _assign_default_kwargs(kws, call_func, source_func):
680    """Assign default kwargs for call_func using values from source_func."""
681    # This exists so that axes-level functions and figure-level functions can
682    # both call a Plotter method while having the default kwargs be defined in
683    # the signature of the axes-level function.
684    # An alternative would be to  have a decorator on the method that sets its
685    # defaults based on those defined in the axes-level function.
686    # Then the figure-level function would not need to worry about defaults.
687    # I am not sure which is better.
688    needed = inspect.signature(call_func).parameters
689    defaults = inspect.signature(source_func).parameters
690
691    for param in needed:
692        if param in defaults and param not in kws:
693            kws[param] = defaults[param].default
694
695    return kws
696
697
def adjust_legend_subtitles(legend):
    """Make invisible-handle "subtitles" entries look more like titles.

    For each legend entry whose handles are not all visible, the handle area
    is collapsed to zero width. When ``legend.title_fontsize`` is set in
    rcParams, the entry's text is also resized to that size.
    """
    # Legend title not in rcParams until 3.0
    title_size = plt.rcParams.get("legend.title_fontsize", None)
    entry_column = legend.findobj(mpl.offsetbox.VPacker)[0]
    for entry in entry_column.get_children():
        handle_box, label_box = entry.get_children()
        if all(artist.get_visible() for artist in handle_box.get_children()):
            continue
        handle_box.set_width(0)
        if title_size is not None:
            for text in label_box.get_children():
                text.set_size(title_size)
711