"""Utility functions, mostly for internal use."""
import os
import re
import inspect
import warnings
import colorsys
from urllib.request import urlopen, urlretrieve

import numpy as np
from scipy import stats
import pandas as pd
import matplotlib as mpl
import matplotlib.colors as mplcol
import matplotlib.pyplot as plt
from matplotlib.cbook import normalize_kwargs


__all__ = ["desaturate", "saturate", "set_hls_values",
           "despine", "get_dataset_names", "get_data_home", "load_dataset"]


def sort_df(df, *args, **kwargs):
    """Wrapper to handle different pandas sorting API pre/post 0.17.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    df : object with a ``sort_values`` or ``sort`` method
        Data to sort.
    args, kwargs :
        Passed through unchanged to the sorting method.

    Returns
    -------
    Whatever the underlying sorting method returns.

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg, FutureWarning)
    # Only the attribute lookup is guarded, so that an AttributeError raised
    # from inside the sort itself is not hidden by the fallback.
    try:
        sorter = df.sort_values
    except AttributeError:
        sorter = df.sort
    return sorter(*args, **kwargs)


def ci_to_errsize(cis, heights):
    """Convert intervals to error arguments relative to plot heights.

    Parameters
    ----------
    cis: 2 x n sequence
        sequence of confidence interval limits
    heights : n sequence
        sequence of plot heights

    Returns
    -------
    errsize : 2 x n array
        sequence of error size relative to height values in correct
        format as argument for plt.bar

    """
    cis = np.atleast_2d(cis).reshape(2, -1)
    heights = np.atleast_1d(heights)

    # One [distance below height, distance above height] pair per interval
    errsize = [
        [heights[i] - low, high - heights[i]]
        for i, (low, high) in enumerate(np.transpose(cis))
    ]

    return np.asarray(errsize).T


def pmf_hist(a, bins=10):
    """Return arguments to plt.bar for pmf-like histogram of an array.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    a: array-like
        array to make histogram of
    bins: int
        number of bins

    Returns
    -------
    x: array
        left x position of bars
    h: array
        height of bars
    w: float
        width of bars

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg, FutureWarning)
    n, x = np.histogram(a, bins)
    # Normalize counts so that the bar heights sum to 1
    h = n / n.sum()
    w = x[1] - x[0]
    return x[:-1], h, w


def desaturate(color, prop):
    """Decrease the saturation channel of a color by some percent.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    prop : float
        saturation channel of color will be multiplied by this value

    Returns
    -------
    new_color : rgb tuple
        desaturated color code in RGB tuple representation

    Raises
    ------
    ValueError
        If ``prop`` is not between 0 and 1 (inclusive).

    """
    # Check inputs
    if not 0 <= prop <= 1:
        raise ValueError("prop must be between 0 and 1")

    # Get rgb tuple rep
    rgb = mplcol.colorConverter.to_rgb(color)

    # Convert to hls
    h, l, s = colorsys.rgb_to_hls(*rgb)  # noqa

    # Desaturate the saturation channel
    s *= prop

    # Convert back to rgb
    new_color = colorsys.hls_to_rgb(h, l, s)

    return new_color


def saturate(color):
    """Return a fully saturated color with the same hue.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name

    Returns
    -------
    new_color : rgb tuple
        saturated color code in RGB tuple representation

    """
    return set_hls_values(color, s=1)


def set_hls_values(color, h=None, l=None, s=None):  # noqa
    """Independently manipulate the h, l, or s channels of a color.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    h, l, s : floats between 0 and 1, or None
        new values for each channel in hls space

    Returns
    -------
    new_color : rgb tuple
        new color code in RGB tuple representation

    """
    # Get an RGB tuple representation
    rgb = mplcol.colorConverter.to_rgb(color)
    vals = list(colorsys.rgb_to_hls(*rgb))

    # Channels passed as None keep the value they have in the input color
    for i, val in enumerate([h, l, s]):
        if val is not None:
            vals[i] = val

    rgb = colorsys.hls_to_rgb(*vals)
    return rgb


def axlabel(xlabel, ylabel, **kwargs):
    """Grab current axis and label it.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    xlabel, ylabel : str
        Labels for the x and y axes of the current axes.
    kwargs :
        Passed to both ``set_xlabel`` and ``set_ylabel``.

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg, FutureWarning)
    ax = plt.gca()
    ax.set_xlabel(xlabel, **kwargs)
    ax.set_ylabel(ylabel, **kwargs)


def remove_na(vector):
    """Helper method for removing null values from data vectors.

    Parameters
    ----------
    vector : vector object
        Must implement boolean masking with [] subscript syntax.

    Returns
    -------
    clean_vector : same type as ``vector``
        Vector of data with null values removed. May be a copy or a view.

    """
    return vector[pd.notnull(vector)]


def get_color_cycle():
    """Return the list of colors in the current matplotlib color cycle

    Parameters
    ----------
    None

    Returns
    -------
    colors : list
        List of matplotlib colors in the current cycle, or dark gray if
        the current color cycle is empty.
    """
    cycler = mpl.rcParams['axes.prop_cycle']
    return cycler.by_key()['color'] if 'color' in cycler.keys else [".15"]


def despine(fig=None, ax=None, top=True, right=True, left=False,
            bottom=False, offset=None, trim=False):
    """Remove the top and right spines from plot(s).

    Parameters
    ----------
    fig : matplotlib figure, optional
        Figure to despine all axes of, defaults to the current figure.
    ax : matplotlib axes, optional
        Specific axes object to despine. Ignored if fig is provided.
    top, right, left, bottom : boolean, optional
        If True, remove that spine.
    offset : int or dict, optional
        Absolute distance, in points, spines should be moved away
        from the axes (negative values move spines inward). A single value
        applies to all spines; a dict can be used to set offset values per
        side.
    trim : bool, optional
        If True, limit spines to the smallest and largest major tick
        on each non-despined axis.

    Returns
    -------
    None

    """
    # Get references to the axes we want
    if fig is None and ax is None:
        axes = plt.gcf().axes
    elif fig is not None:
        axes = fig.axes
    elif ax is not None:
        axes = [ax]

    # Explicit side -> "remove this spine" mapping
    remove = {"top": top, "right": right, "left": left, "bottom": bottom}

    for ax_i in axes:
        for side, hidden in remove.items():
            # Toggle the spine objects
            is_visible = not hidden
            ax_i.spines[side].set_visible(is_visible)
            if offset is not None and is_visible:
                # A dict gives per-side offsets (0 for missing sides);
                # anything without ``.get`` is used as a single value.
                try:
                    val = offset.get(side, 0)
                except AttributeError:
                    val = offset
                ax_i.spines[side].set_position(('outward', val))

        # Potentially move the ticks
        if left and not right:
            # Record tick visibility before moving the ticks so that it
            # can be reapplied on the right-hand side.
            maj_on = any(
                t.tick1line.get_visible()
                for t in ax_i.yaxis.majorTicks
            )
            min_on = any(
                t.tick1line.get_visible()
                for t in ax_i.yaxis.minorTicks
            )
            ax_i.yaxis.set_ticks_position("right")
            for t in ax_i.yaxis.majorTicks:
                t.tick2line.set_visible(maj_on)
            for t in ax_i.yaxis.minorTicks:
                t.tick2line.set_visible(min_on)

        if bottom and not top:
            # Same as above, for the x axis moving to the top
            maj_on = any(
                t.tick1line.get_visible()
                for t in ax_i.xaxis.majorTicks
            )
            min_on = any(
                t.tick1line.get_visible()
                for t in ax_i.xaxis.minorTicks
            )
            ax_i.xaxis.set_ticks_position("top")
            for t in ax_i.xaxis.majorTicks:
                t.tick2line.set_visible(maj_on)
            for t in ax_i.xaxis.minorTicks:
                t.tick2line.set_visible(min_on)

        if trim:
            # clip off the parts of the spines that extend past major ticks
            xticks = np.asarray(ax_i.get_xticks())
            if xticks.size:
                firsttick = np.compress(xticks >= min(ax_i.get_xlim()),
                                        xticks)[0]
                lasttick = np.compress(xticks <= max(ax_i.get_xlim()),
                                       xticks)[-1]
                ax_i.spines['bottom'].set_bounds(firsttick, lasttick)
                ax_i.spines['top'].set_bounds(firsttick, lasttick)
                newticks = xticks.compress(xticks <= lasttick)
                newticks = newticks.compress(newticks >= firsttick)
                ax_i.set_xticks(newticks)

            yticks = np.asarray(ax_i.get_yticks())
            if yticks.size:
                firsttick = np.compress(yticks >= min(ax_i.get_ylim()),
                                        yticks)[0]
                lasttick = np.compress(yticks <= max(ax_i.get_ylim()),
                                       yticks)[-1]
                ax_i.spines['left'].set_bounds(firsttick, lasttick)
                ax_i.spines['right'].set_bounds(firsttick, lasttick)
                newticks = yticks.compress(yticks <= lasttick)
                newticks = newticks.compress(newticks >= firsttick)
                ax_i.set_yticks(newticks)


def _kde_support(data, bw, gridsize, cut, clip):
    """Establish support for a kernel density estimate.

    The grid has ``gridsize`` evenly spaced points. It extends ``bw * cut``
    beyond the extremes of ``data`` and is bounded by the ``clip`` pair
    ``(low, high)``.
    """
    support_min = max(data.min() - bw * cut, clip[0])
    support_max = min(data.max() + bw * cut, clip[1])
    support = np.linspace(support_min, support_max, gridsize)

    return support


def percentiles(a, pcts, axis=None):
    """Like scoreatpercentile but can take and return array of percentiles.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    a : array
        data
    pcts : sequence of percentile values
        percentile or percentiles to find score at
    axis : int or None
        if not None, computes scores over this axis

    Returns
    -------
    scores: array
        array of scores at requested percentiles
        first dimension is length of object passed to ``pcts``

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg, FutureWarning)

    # A bare scalar has no len(); wrap it and squeeze the result afterwards
    try:
        len(pcts)
    except TypeError:
        pcts = [pcts]
        scalar_input = True
    else:
        scalar_input = False

    scores = []
    for p in pcts:
        if axis is None:
            score = stats.scoreatpercentile(a.ravel(), p)
        else:
            score = np.apply_along_axis(stats.scoreatpercentile, axis, a, p)
        scores.append(score)

    scores = np.asarray(scores)
    if scalar_input:
        scores = scores.squeeze()
    return scores


def ci(a, which=95, axis=None):
    """Return a percentile range from an array of values.

    The range is centered on the median and spans ``which`` percent of the
    distribution, e.g. the 2.5 and 97.5 percentiles for ``which=95``.
    """
    p = 50 - which / 2, 50 + which / 2
    return np.percentile(a, p, axis)


def sig_stars(p):
    """Return a R-style significance string corresponding to p values.

    DEPRECATED: will be removed in a future version.

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg, FutureWarning)

    if p < 0.001:
        return "***"
    elif p < 0.01:
        return "**"
    elif p < 0.05:
        return "*"
    elif p < 0.1:
        return "."
    return ""


def iqr(a):
    """Calculate the IQR for an array of numbers.

    DEPRECATED: will be removed in a future version.

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg, FutureWarning)

    a = np.asarray(a)
    q1 = stats.scoreatpercentile(a, 25)
    q3 = stats.scoreatpercentile(a, 75)
    return q3 - q1


def get_dataset_names():
    """Report available example datasets, useful for reporting issues.

    Requires an internet connection.

    Returns
    -------
    datasets : list of str
        Names matched in the repository page, without the ``.csv`` suffix.

    """
    url = "https://github.com/mwaskom/seaborn-data"
    with urlopen(url) as resp:
        html = resp.read()

    # The dot is escaped so that only a literal ".csv" suffix matches
    pat = r"/mwaskom/seaborn-data/blob/master/(\w*)\.csv"
    datasets = re.findall(pat, html.decode())
    return datasets


def get_data_home(data_home=None):
    """Return a path to the cache directory for example datasets.

    This directory is then used by :func:`load_dataset`.

    If the ``data_home`` argument is not specified, it tries to read from the
    ``SEABORN_DATA`` environment variable and defaults to ``~/seaborn-data``.

    The directory is created if it does not already exist.

    """
    if data_home is None:
        data_home = os.environ.get('SEABORN_DATA',
                                   os.path.join('~', 'seaborn-data'))
    data_home = os.path.expanduser(data_home)
    # exist_ok avoids failing when the directory appears between a separate
    # existence check and the creation call.
    os.makedirs(data_home, exist_ok=True)
    return data_home


def load_dataset(name, cache=True, data_home=None, **kws):
    """Load an example dataset from the online repository (requires internet).

    This function provides quick access to a small number of example datasets
    that are useful for documenting seaborn or generating reproducible examples
    for bug reports. It is not necessary for normal usage.

    Note that some of the datasets have a small amount of preprocessing applied
    to define a proper ordering for categorical variables.

    Use :func:`get_dataset_names` to see a list of available datasets.

    Parameters
    ----------
    name : str
        Name of the dataset (``{name}.csv`` on
        https://github.com/mwaskom/seaborn-data).
    cache : boolean, optional
        If True, try to load from the local cache first, and save to the cache
        if a download is required.
    data_home : string, optional
        The directory in which to cache data; see :func:`get_data_home`.
    kws : keys and values, optional
        Additional keyword arguments are passed through to
        :func:`pandas.read_csv`.

    Returns
    -------
    df : :class:`pandas.DataFrame`
        Tabular data, possibly with some preprocessing applied.

    Raises
    ------
    ValueError
        If ``cache`` is True, the file is not cached yet, and ``name`` is not
        reported by :func:`get_dataset_names`.

    """
    path = ("https://raw.githubusercontent.com/"
            "mwaskom/seaborn-data/master/{}.csv")
    full_path = path.format(name)

    if cache:
        cache_path = os.path.join(get_data_home(data_home),
                                  os.path.basename(full_path))
        if not os.path.exists(cache_path):
            if name not in get_dataset_names():
                raise ValueError(f"'{name}' is not one of the example datasets.")
            urlretrieve(full_path, cache_path)
        full_path = cache_path

    df = pd.read_csv(full_path, **kws)

    # Drop a trailing all-null row; the length check keeps an empty frame
    # (e.g. ``nrows=0``) from raising IndexError on ``iloc[-1]``.
    if len(df) and df.iloc[-1].isnull().all():
        df = df.iloc[:-1]

    # Set some columns as a categorical type with ordered levels

    if name == "tips":
        df["day"] = pd.Categorical(df["day"], ["Thur", "Fri", "Sat", "Sun"])
        df["sex"] = pd.Categorical(df["sex"], ["Male", "Female"])
        df["time"] = pd.Categorical(df["time"], ["Lunch", "Dinner"])
        df["smoker"] = pd.Categorical(df["smoker"], ["Yes", "No"])

    if name == "flights":
        months = df["month"].str[:3]
        df["month"] = pd.Categorical(months, months.unique())

    if name == "exercise":
        df["time"] = pd.Categorical(df["time"], ["1 min", "15 min", "30 min"])
        df["kind"] = pd.Categorical(df["kind"], ["rest", "walking", "running"])
        df["diet"] = pd.Categorical(df["diet"], ["no fat", "low fat"])

    if name == "titanic":
        df["class"] = pd.Categorical(df["class"], ["First", "Second", "Third"])
        df["deck"] = pd.Categorical(df["deck"], list("ABCDEFG"))

    if name == "penguins":
        df["sex"] = df["sex"].str.title()

    if name == "diamonds":
        df["color"] = pd.Categorical(
            df["color"], ["D", "E", "F", "G", "H", "I", "J"],
        )
        df["clarity"] = pd.Categorical(
            df["clarity"], ["IF", "VVS1", "VVS2", "VS1", "VS2", "SI1", "SI2", "I1"],
        )
        df["cut"] = pd.Categorical(
            df["cut"], ["Ideal", "Premium", "Very Good", "Good", "Fair"],
        )

    return df


def axis_ticklabels_overlap(labels):
    """Return a boolean for whether the list of ticklabels have overlaps.

    Parameters
    ----------
    labels : list of matplotlib ticklabels

    Returns
    -------
    overlap : boolean
        True if any of the labels overlap.

    """
    if not labels:
        return False
    try:
        bboxes = [label.get_window_extent() for label in labels]
        # Each box is counted against the full list, itself included, so a
        # count above 1 means it overlaps at least one other label.
        overlaps = [b.count_overlaps(bboxes) for b in bboxes]
        return max(overlaps) > 1
    except RuntimeError:
        # Issue on macos backend raises an error in the above code
        return False


def axes_ticklabels_overlap(ax):
    """Return booleans for whether the x and y ticklabels on an Axes overlap.

    Parameters
    ----------
    ax : matplotlib Axes

    Returns
    -------
    x_overlap, y_overlap : booleans
        True when the labels on that axis overlap.

    """
    return (axis_ticklabels_overlap(ax.get_xticklabels()),
            axis_ticklabels_overlap(ax.get_yticklabels()))


def locator_to_legend_entries(locator, limits, dtype):
    """Return levels and formatted levels for brief numeric legends.

    Parameters
    ----------
    locator : matplotlib ticker locator
        Used to choose the levels via ``tick_values(*limits)``.
    limits : pair of numbers
        Lower and upper bound; levels outside this range are dropped.
    dtype : data type
        The tick values are cast to this type.

    Returns
    -------
    raw_levels : list
        Level values within ``limits``.
    formatted_levels : list
        Formatter output for each entry of ``raw_levels``.

    """
    raw_levels = locator.tick_values(*limits).astype(dtype)

    # The locator can return ticks outside the limits, clip them here
    raw_levels = [l for l in raw_levels if l >= limits[0] and l <= limits[1]]  # noqa

    # Minimal stand-in for an axis: the formatter only gets a view interval
    class dummy_axis:
        def get_view_interval(self):
            return limits

    if isinstance(locator, mpl.ticker.LogLocator):
        formatter = mpl.ticker.LogFormatter()
    else:
        formatter = mpl.ticker.ScalarFormatter()
    formatter.axis = dummy_axis()

    # TODO: The following two lines should be replaced
    # once pinned matplotlib>=3.1.0 with:
    # formatted_levels = formatter.format_ticks(raw_levels)
    formatter.set_locs(raw_levels)
    formatted_levels = [formatter(x) for x in raw_levels]

    return raw_levels, formatted_levels


def relative_luminance(color):
    """Calculate the relative luminance of a color according to W3C standards

    Parameters
    ----------
    color : matplotlib color or sequence of matplotlib colors
        Hex code, rgb-tuple, or html color name.

    Returns
    -------
    luminance : float(s) between 0 and 1

    """
    rgb = mpl.colors.colorConverter.to_rgba_array(color)[:, :3]
    rgb = np.where(rgb <= .03928, rgb / 12.92, ((rgb + .055) / 1.055) ** 2.4)
    lum = rgb.dot([.2126, .7152, .0722])
    # A single color yields a one-element array that converts to a scalar;
    # for several colors ``item()`` raises and the array is returned.
    try:
        return lum.item()
    except ValueError:
        return lum


def to_utf8(obj):
    """Return a string representing a Python object.

    Strings (i.e. type ``str``) are returned unchanged.

    Byte strings (i.e. type ``bytes``) are returned as UTF-8-decoded strings.

    For other objects, the method ``__str__()`` is called, and the result is
    returned as a string.

    Parameters
    ----------
    obj : object
        Any Python object

    Returns
    -------
    s : str
        UTF-8-decoded string representation of ``obj``

    """
    if isinstance(obj, str):
        return obj
    try:
        return obj.decode(encoding="utf-8")
    except AttributeError:  # obj is not bytes-like
        return str(obj)


def _normalize_kwargs(kws, artist):
    """Wrapper for mpl.cbook.normalize_kwargs that supports <= 3.2.1."""
    _alias_map = {
        'color': ['c'],
        'linewidth': ['lw'],
        'linestyle': ['ls'],
        'facecolor': ['fc'],
        'edgecolor': ['ec'],
        'markerfacecolor': ['mfc'],
        'markeredgecolor': ['mec'],
        'markeredgewidth': ['mew'],
        'markersize': ['ms']
    }
    # Try passing the artist first; if that raises AttributeError, fall back
    # to the explicit alias map above.
    try:
        kws = normalize_kwargs(kws, artist)
    except AttributeError:
        kws = normalize_kwargs(kws, _alias_map)
    return kws


def _check_argument(param, options, value):
    """Raise if value for param is not in options.

    Raises
    ------
    ValueError
        If ``value`` is not contained in ``options``.

    """
    if value not in options:
        raise ValueError(
            f"`{param}` must be one of {options}, but {value} was passed."
        )


def _assign_default_kwargs(kws, call_func, source_func):
    """Assign default kwargs for call_func using values from source_func.

    ``kws`` is updated in place and also returned. A parameter is filled in
    only when it appears in both signatures and is not already in ``kws``.
    """
    # This exists so that axes-level functions and figure-level functions can
    # both call a Plotter method while having the default kwargs be defined in
    # the signature of the axes-level function.
    # An alternative would be to have a decorator on the method that sets its
    # defaults based on those defined in the axes-level function.
    # Then the figure-level function would not need to worry about defaults.
    # I am not sure which is better.
    needed = inspect.signature(call_func).parameters
    defaults = inspect.signature(source_func).parameters

    for param in needed:
        if param in defaults and param not in kws:
            kws[param] = defaults[param].default

    return kws


def adjust_legend_subtitles(legend):
    """Make invisible-handle "subtitles" entries look more like titles."""
    # Legend title not in rcParams until 3.0
    font_size = plt.rcParams.get("legend.title_fontsize", None)
    hpackers = legend.findobj(mpl.offsetbox.VPacker)[0].get_children()
    for hpack in hpackers:
        draw_area, text_area = hpack.get_children()
        handles = draw_area.get_children()
        # Entries with an invisible handle are treated as subtitles: collapse
        # the handle area and apply the title font size when one is set.
        if not all(artist.get_visible() for artist in handles):
            draw_area.set_width(0)
            for text in text_area.get_children():
                if font_size is not None:
                    text.set_size(font_size)