1import contextlib
2
3
def check_args(x, y, rope=0, prior=1, nsamples=50000):
    """Validate the arguments shared by the comparison tests.

    Args:
        x: first sample; must expose ``ndim`` and support ``len`` (e.g. a
            numpy array) and be 1-dimensional.
        y: second sample; same requirements as ``x``, and the same length.
        rope: width of the rope; must be non-negative (NaN is rejected).
        prior: prior strength; must be non-negative (NaN is rejected).
        nsamples: number of samples; must be a positive integral value.
            Non-finite values (inf, NaN) are rejected.

    Raises:
        ValueError: if any argument violates the constraints above.
    """
    if x.ndim != 1:
        raise ValueError("'x' must be a 1-dimensional array")
    if y.ndim != 1:
        raise ValueError("'y' must be a 1-dimensional array")
    if len(x) != len(y):
        raise ValueError("'x' and 'y' must be of same length")
    # `not v >= 0` rather than `v < 0`: NaN compares false with everything,
    # so `v < 0` would let a NaN slip through unnoticed.
    if not rope >= 0:
        raise ValueError('Rope width cannot be negative')
    if not prior >= 0:
        raise ValueError('Prior strength cannot be negative')
    try:
        # Chained comparison: integral value AND strictly positive.
        is_positive_int = round(nsamples) == nsamples > 0
    except (OverflowError, ValueError):
        # round() raises OverflowError for +/-inf and ValueError for NaN;
        # report those with the same error as any other invalid count.
        is_positive_int = False
    if not is_positive_int:
        raise ValueError('Number of samples must be a positive integer')
17
18
def call_shortcut(test, x, y, rope, *args, plot=False, names=None, **kwargs):
    """Run ``test`` and return its probabilities, optionally with a plot.

    ``test`` is called as ``test(x, y, rope, *args, **kwargs)``. The object
    it returns must provide ``probs()`` and, when ``plot`` is true,
    ``plot(names)``.

    Returns:
        ``result.probs()`` when ``plot`` is false. Otherwise the pair
        ``(result.probs(), result.plot(names))``.
    """
    result = test(x, y, rope, *args, **kwargs)
    probabilities = result.probs()
    if not plot:
        return probabilities
    return probabilities, result.plot(names)
25
26
@contextlib.contextmanager
def seaborn_plt():
    """Temporarily apply a Seaborn-like matplotlib style and yield ``pyplot``.

    On entry, a fixed set of ``rcParams`` is overridden. The color shorthands
    'b', 'g', 'r', 'm', 'y', 'c' and 'k' are remapped to a Seaborn-like
    palette. All of this global state is restored on exit, including when
    the body of the ``with`` statement raises or when setup itself fails
    part-way.

    Yields:
        The ``matplotlib.pyplot`` module.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    # Set a Seaborn-like style. See https://github.com/mwaskom/seaborn.
    try:
        import matplotlib as mpl
    except ImportError as err:
        raise ImportError("Plotting requires 'matplotlib'; "
                          "use 'pip install matplotlib' to install it") from err

    params = {
        "font.size": 12,
        "text.color": ".15",
        "font.sans-serif": ['Arial', 'DejaVu Sans', 'Liberation Sans',
                            'Bitstream Vera Sans', 'sans-serif'],
        "legend.fontsize": 11,

        "axes.labelsize": 12,
        "axes.labelcolor": ".15",
        "axes.axisbelow": True,
        "axes.facecolor": "#EAEAF2",
        "axes.edgecolor": "white",
        "axes.linewidth": 1.25,

        "grid.linewidth": 1,
        "grid.color": "white",

        "xtick.labelsize": 11,
        "xtick.color": ".15",
        "xtick.major.width": 1.25,
        "ytick.left": False,

        "lines.solid_capstyle": "round",
        "patch.edgecolor": "w",
        "patch.force_edgecolor": True}
    colors = dict(b="#4C72B0", g="#55A868", r="#C44E52",
                  m="#8172B3", y="#CCB974", c="#64B5CD",
                  k=(.1, .1, .1, .1))

    converter = mpl.colors.colorConverter
    # Snapshot every piece of global state before touching any of it, so
    # the `finally` clause can always restore it, even if setup fails.
    orig_params = {k: mpl.rcParams[k] for k in params}
    orig_colors = converter.colors.copy()
    orig_cache = converter.cache.copy()
    try:
        mpl.rcParams.update(params)

        colors = {k: converter.to_rgb(v) for k, v in colors.items()}
        converter.colors.update(colors)
        # Drop conversions cached before entering. Otherwise a shorthand
        # such as 'b' that was already converted earlier could keep
        # resolving to its old value from the cache. Then seed the cache
        # with the new values.
        converter.cache.clear()
        converter.cache.update(colors)

        import matplotlib.pyplot as plt
        yield plt
    finally:
        mpl.rcParams.update(orig_params)
        converter.colors.clear()
        converter.colors.update(orig_colors)
        converter.cache.clear()
        converter.cache.update(orig_cache)
83