import contextlib


def check_args(x, y, rope=0, prior=1, nsamples=50000):
    """Validate the arguments shared by the test functions.

    Args:
        x, y: 1-dimensional arrays (objects exposing ``ndim`` and ``len``)
            of equal length.
        rope: rope width; must be non-negative.
        prior: prior strength; must be non-negative.
        nsamples: number of samples; must be a positive integer. An
            integral float such as ``1000.0`` is accepted.

    Raises:
        ValueError: if any argument is invalid (including a NaN or
            infinite ``nsamples``).
    """
    if x.ndim != 1:
        raise ValueError("'x' must be a 1-dimensional array")
    if y.ndim != 1:
        raise ValueError("'y' must be a 1-dimensional array")
    if len(x) != len(y):
        raise ValueError("'x' and 'y' must be of same length")
    if rope < 0:
        raise ValueError('Rope width cannot be negative')
    if prior < 0:
        raise ValueError('Prior strength cannot be negative')
    # round() raises ValueError for NaN and OverflowError for +/-inf; treat
    # both as "not an integer" so callers only ever see ValueError here.
    try:
        is_integral = round(nsamples) == nsamples
    except (ValueError, OverflowError):
        is_integral = False
    if not (is_integral and nsamples > 0):
        raise ValueError('Number of samples must be a positive integer')


def call_shortcut(test, x, y, rope, *args, plot=False, names=None, **kwargs):
    """Run ``test`` on the data and return its probabilities.

    ``test`` is called as ``test(x, y, rope, *args, **kwargs)``. The object
    it returns must provide ``probs()`` and, when ``plot`` is true,
    ``plot(names)``.

    Returns:
        ``sample.probs()``; or the pair ``(sample.probs(), sample.plot(names))``
        when ``plot`` is true.
    """
    sample = test(x, y, rope, *args, **kwargs)
    if plot:
        return sample.probs(), sample.plot(names)
    return sample.probs()


@contextlib.contextmanager
def seaborn_plt():
    """Temporarily apply a Seaborn-like style to matplotlib.

    Yields the ``matplotlib.pyplot`` module. While the context is active,
    the rc parameters below and the single-letter color codes ('b', 'g',
    'r', 'm', 'y', 'c', 'k') are overridden. On exit, including exit by
    exception, the previous values are restored.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    # Set a Seaborn-like style. See https://github.com/mwaskom/seaborn.
    try:
        import matplotlib as mpl
        # Import the submodule explicitly instead of relying on
        # ``import matplotlib`` having loaded it as a side effect.
        import matplotlib.colors as mcolors
    except ImportError as err:
        raise ImportError("Plotting requires 'matplotlib'; "
                          "use 'pip install matplotlib' to install it") from err

    params = {
        "font.size": 12,
        "text.color": ".15",
        "font.sans-serif": ['Arial', 'DejaVu Sans', 'Liberation Sans',
                            'Bitstream Vera Sans', 'sans-serif'],
        "legend.fontsize": 11,

        "axes.labelsize": 12,
        "axes.labelcolor": ".15",
        "axes.axisbelow": True,
        "axes.facecolor": "#EAEAF2",
        "axes.edgecolor": "white",
        "axes.linewidth": 1.25,

        "grid.linewidth": 1,
        "grid.color": "white",

        "xtick.labelsize": 11,
        "xtick.color": ".15",
        "xtick.major.width": 1.25,
        "ytick.left": False,

        "lines.solid_capstyle": "round",
        "patch.edgecolor": "w",
        "patch.force_edgecolor": True}

    converter = mcolors.colorConverter
    colors = dict(b="#4C72B0", g="#55A868", r="#C44E52",
                  m="#8172B3", y="#CCB974", c="#64B5CD",
                  k=(.1, .1, .1))
    colors = {k: converter.to_rgb(v) for k, v in colors.items()}

    # Snapshot everything before modifying anything, and do all of the
    # modification inside the try, so the finally clause restores global
    # matplotlib state even if a later step (e.g. importing pyplot) fails.
    orig_params = {k: mpl.rcParams[k] for k in params}
    orig_colors = converter.colors.copy()
    orig_cache = converter.cache.copy()
    try:
        mpl.rcParams.update(params)
        converter.colors.update(colors)
        # Drop cached conversions: entries computed with the previous color
        # codes would otherwise shadow the new ones. (dict.update above does
        # not go through any __setitem__-based cache invalidation.) The
        # update keeps the original pre-seeding of the cache, which
        # presumably targets matplotlib versions keyed by the color spec.
        converter.cache.clear()
        converter.cache.update(colors)

        import matplotlib.pyplot as plt
        yield plt
    finally:
        mpl.rcParams.update(orig_params)
        converter.colors.clear()
        converter.colors.update(orig_colors)
        converter.cache.clear()
        converter.cache.update(orig_cache)