"""Tests for the mosaic plot drawing function and its splitting helpers."""
from statsmodels.compat.python import lrange

from io import BytesIO
from itertools import product

import numpy as np
from numpy.testing import assert_, assert_raises
import pandas as pd
import pytest

from statsmodels.api import datasets

# utilities for the tests

# matplotlib is optional: the tests that need it carry the
# ``pytest.mark.matplotlib`` marker, so a missing install is tolerated here.
try:
    import matplotlib.pyplot as plt  # noqa:F401
except ImportError:
    pass

# other functions to be tested for accuracy
# the main drawing function
from statsmodels.graphics.mosaicplot import (
    _hierarchical_split,
    _key_splitting,
    _normalize_split,
    _reduce_dict,
    _split_rect,
    mosaic,
)


@pytest.mark.matplotlib
def test_data_conversion(close_figures):
    """Smoke test: mosaic accepts dicts, Series, lists, arrays, DataFrames."""
    # It will not reorder the elements, so the dictionary will look odd
    # as its key order has the c and b keys swapped
    _, ax = plt.subplots(4, 4)
    data = {'ax': 1, 'bx': 2, 'cx': 3}
    mosaic(data, ax=ax[0, 0], title='basic dict', axes_label=False)
    data = pd.Series(data)
    mosaic(data, ax=ax[0, 1], title='basic series', axes_label=False)
    data = [1, 2, 3]
    mosaic(data, ax=ax[0, 2], title='basic list', axes_label=False)
    data = np.asarray(data)
    mosaic(data, ax=ax[0, 3], title='basic array', axes_label=False)
    plt.close("all")

    data = {('ax', 'cx'): 1, ('bx', 'cx'): 2,
            ('ax', 'dx'): 3, ('bx', 'dx'): 4}
    mosaic(data, ax=ax[1, 0], title='compound dict', axes_label=False)
    mosaic(data, ax=ax[2, 0], title='inverted keys dict', index=[1, 0],
           axes_label=False)
    data = pd.Series(data)
    mosaic(data, ax=ax[1, 1], title='compound series', axes_label=False)
    mosaic(data, ax=ax[2, 1], title='inverted keys series', index=[1, 0])
    data = [[1, 2], [3, 4]]
    mosaic(data, ax=ax[1, 2], title='compound list', axes_label=False)
    mosaic(data, ax=ax[2, 2], title='inverted keys list', index=[1, 0])
    data = np.array([[1, 2], [3, 4]])
    mosaic(data, ax=ax[1, 3], title='compound array', axes_label=False)
    mosaic(data, ax=ax[2, 3], title='inverted keys array', index=[1, 0],
           axes_label=False)
    plt.close("all")

    gender = ['male', 'male', 'male', 'female', 'female', 'female']
    pet = ['cat', 'dog', 'dog', 'cat', 'dog', 'cat']
    data = pd.DataFrame({'gender': gender, 'pet': pet})
    mosaic(data, ['gender'], ax=ax[3, 0], title='dataframe by key 1',
           axes_label=False)
    mosaic(data, ['pet'], ax=ax[3, 1], title='dataframe by key 2',
           axes_label=False)
    mosaic(data, ['gender', 'pet'], ax=ax[3, 2], title='both keys',
           axes_label=False)
    mosaic(data, ['pet', 'gender'], ax=ax[3, 3], title='keys inverted',
           axes_label=False)
    plt.close("all")
    # NOTE(review): every figure was closed on the line above, so this title
    # presumably lands on a newly created figure rather than on the 4x4
    # grid -- confirm whether that is intended.
    plt.suptitle('testing data conversion (plot 1 of 4)')


@pytest.mark.matplotlib
def test_mosaic_simple(close_figures):
    """Smoke test: four-level synthetic data with per-key properties."""
    # display a simple plot of 4 categories of data, split in four
    # levels with increasing size for each group
    # creation of the levels
    key_set = (['male', 'female'], ['old', 'adult', 'young'],
               ['worker', 'unemployed'], ['healty', 'ill'])
    # the cartesian product of all the categories is
    # the complete set of categories
    keys = list(product(*key_set))
    data = dict(zip(keys, range(1, 1 + len(keys))))
    # which colours should I use for the various categories?
    # put it into a dict
    props = {}
    # males and females in blue and red
    props[('male',)] = {'color': 'b'}
    props[('female',)] = {'color': 'r'}
    # all the groups corresponding to ill groups have a different color
    for key in keys:
        if 'ill' in key:
            if 'male' in key:
                props[key] = {'color': 'BlueViolet', 'hatch': '+'}
            else:
                props[key] = {'color': 'Crimson', 'hatch': '+'}
    # mosaic of the data, with given gaps and colors
    mosaic(data, gap=0.05, properties=props, axes_label=False)
    plt.suptitle('syntetic data, 4 categories (plot 2 of 4)')


@pytest.mark.matplotlib
def test_mosaic(close_figures):
    """Smoke test: mosaic plots built from the ``fair`` dataset columns."""
    # make the same analysis on a known dataset

    # load the data and clean it a bit
    affairs = datasets.fair.load_pandas()
    datas = affairs.exog
    # any time greater than 0 is cheating
    datas['cheated'] = affairs.endog > 0
    # sort by the marriage quality and give meaningful name
    # [rate_marriage, age, yrs_married, children,
    # religious, educ, occupation, occupation_husb]
    datas = datas.sort_values(['rate_marriage', 'religious'])

    num_to_desc = {1: 'awful', 2: 'bad', 3: 'intermediate',
                   4: 'good', 5: 'wonderful'}
    datas['rate_marriage'] = datas['rate_marriage'].map(num_to_desc)
    num_to_faith = {1: 'non religious', 2: 'poorly religious',
                    3: 'religious', 4: 'very religious'}
    datas['religious'] = datas['religious'].map(num_to_faith)
    num_to_cheat = {False: 'faithful', True: 'cheated'}
    datas['cheated'] = datas['cheated'].map(num_to_cheat)
    # finished cleaning
    _, ax = plt.subplots(2, 2)
    mosaic(datas, ['rate_marriage', 'cheated'], ax=ax[0, 0],
           title='by marriage happiness')
    mosaic(datas, ['religious', 'cheated'], ax=ax[0, 1],
           title='by religiosity')
    mosaic(datas, ['rate_marriage', 'religious', 'cheated'], ax=ax[1, 0],
           title='by both', labelizer=lambda k: '')
    ax[1, 0].set_xlabel('marriage rating')
    ax[1, 0].set_ylabel('religion status')
    mosaic(datas, ['religious', 'rate_marriage'], ax=ax[1, 1],
           title='inter-dependence', axes_label=False)
    plt.suptitle("extramarital affairs (plot 3 of 4)")


@pytest.mark.matplotlib
def test_mosaic_very_complex(close_figures):
    """Smoke test: a grid of pairwise mosaic plots over four variables."""
    # make a scattermatrix of mosaic plots to show the correlations between
    # each pair of variable in a dataset. Could be easily converted into a
    # new function that does this automatically based on the type of data
    key_name = ['gender', 'age', 'health', 'work']
    key_base = (['male', 'female'], ['old', 'young'],
                ['healty', 'ill'], ['work', 'unemployed'])
    keys = list(product(*key_base))
    data = dict(zip(keys, range(1, 1 + len(keys))))
    props = {}
    props[('male', 'old')] = {'color': 'r'}
    props[('female',)] = {'color': 'pink'}
    L = len(key_base)
    _, axes = plt.subplots(L, L)
    for i in range(L):
        for j in range(L):
            # indices of the variables that are not on this cell's axes
            m = set(range(L)).difference({i, j})
            if i == j:
                # diagonal cells only carry the variable name
                axes[i, i].text(0.5, 0.5, key_name[i],
                                ha='center', va='center')
                axes[i, i].set_xticks([])
                axes[i, i].set_xticklabels([])
                axes[i, i].set_yticks([])
                axes[i, i].set_yticklabels([])
            else:
                ji = max(i, j)
                ij = min(i, j)
                # move the two plotted variables to the front of each key
                temp_data = {
                    (k[ij], k[ji]) + tuple(k[r] for r in m): v
                    for k, v in data.items()
                }

                # collapse the full keys onto their two-element prefix
                # NOTE(review): temp_data is modified while it is being
                # reduced, so later keys sharing a prefix presumably also
                # pick up the already-aggregated entry -- confirm intent.
                full_keys = list(temp_data.keys())
                for k in full_keys:
                    value = _reduce_dict(temp_data, k[:2])
                    temp_data[k[:2]] = value
                    del temp_data[k]
                mosaic(temp_data, ax=axes[i, j], axes_label=False,
                       properties=props, gap=0.05, horizontal=i > j)
    plt.suptitle('old males should look bright red, (plot 4 of 4)')


@pytest.mark.matplotlib
def test_axes_labeling(close_figures):
    """Smoke test: custom labelizer with scalar and per-level rotations."""
    from numpy.random import rand
    key_set = (['male', 'female'], ['old', 'adult', 'young'],
               ['worker', 'unemployed'], ['yes', 'no'])
    # the cartesian product of all the categories is
    # the complete set of categories
    keys = list(product(*key_set))
    data = dict(zip(keys, rand(len(keys))))

    def lab(k):
        # label each tile with the initials of its key components
        return ''.join(s[0] for s in k)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    mosaic(data, ax=ax1, labelizer=lab, horizontal=True, label_rotation=45)
    mosaic(data, ax=ax2, labelizer=lab, horizontal=False,
           label_rotation=[0, 45, 90, 0])
    fig.suptitle("correct alignment of the axes labels")


@pytest.mark.smoke
@pytest.mark.matplotlib
def test_mosaic_empty_cells(close_figures):
    """Smoke test: crosstab input where some id1/id2 pairs never occur."""
    # GH#2286
    mydata = pd.DataFrame({'id2': {64: 'Angelica',
                                   65: 'DXW_UID', 66: 'casuid01',
                                   67: 'casuid01', 68: 'EC93_uid',
                                   69: 'EC93_uid', 70: 'EC93_uid',
                                   60: 'DXW_UID', 61: 'AtmosFox',
                                   62: 'DXW_UID', 63: 'DXW_UID'},
                           'id1': {64: 'TGP',
                                   65: 'Retention01', 66: 'default',
                                   67: 'default', 68: 'Musa_EC_9_3',
                                   69: 'Musa_EC_9_3', 70: 'Musa_EC_9_3',
                                   60: 'default', 61: 'default',
                                   62: 'default', 63: 'default'}})

    ct = pd.crosstab(mydata.id1, mydata.id2)
    mosaic(ct.T.unstack())
    mosaic(mydata, ['id1', 'id2'])


def eq(x, y):
    """Assert that ``x`` and ``y`` are element-wise close (np.allclose)."""
    assert_(np.allclose(x, y))


def test_recursive_split():
    """Check key order and rectangles produced by ``_hierarchical_split``."""
    keys = list(product('mf'))
    data = dict(zip(keys, [1] * len(keys)))
    res = _hierarchical_split(data, gap=0)
    assert_(list(res.keys()) == keys)
    # the expected rectangles are compared, not assigned: writing them into
    # ``res`` would make this test pass regardless of what was computed
    eq(res[('m',)], (0.0, 0.0, 0.5, 1.0))
    eq(res[('f',)], (0.5, 0.0, 0.5, 1.0))
    keys = list(product('mf', 'yao'))
    data = dict(zip(keys, [1] * len(keys)))
    res = _hierarchical_split(data, gap=0)
    assert_(list(res.keys()) == keys)
    eq(res[('m', 'y')], (0.0, 0.0, 0.5, 1 / 3))
    eq(res[('m', 'a')], (0.0, 1 / 3, 0.5, 1 / 3))
    eq(res[('m', 'o')], (0.0, 2 / 3, 0.5, 1 / 3))
    eq(res[('f', 'y')], (0.5, 0.0, 0.5, 1 / 3))
    eq(res[('f', 'a')], (0.5, 1 / 3, 0.5, 1 / 3))
    eq(res[('f', 'o')], (0.5, 2 / 3, 0.5, 1 / 3))


def test__reduce_dict():
    """Check the totals ``_reduce_dict`` returns for partial keys."""
    data = dict(zip(list(product('mf', 'oy', 'wn')), [1] * 8))
    eq(_reduce_dict(data, ('m',)), 4)
    eq(_reduce_dict(data, ('m', 'o')), 2)
    eq(_reduce_dict(data, ('m', 'o', 'w')), 1)
    data = dict(zip(list(product('mf', 'oy', 'wn')), lrange(8)))
    eq(_reduce_dict(data, ('m',)), 6)
    eq(_reduce_dict(data, ('m', 'o')), 1)
    eq(_reduce_dict(data, ('m', 'o', 'w')), 0)


def test__key_splitting():
    """Check keys and rectangles produced by ``_key_splitting``."""
    # subdivide starting with an empty tuple
    base_rect = {tuple(): (0, 0, 1, 1)}
    res = _key_splitting(base_rect, ['a', 'b'], [1, 1], tuple(), True, 0)
    assert_(list(res.keys()) == [('a',), ('b',)])
    eq(res[('a',)], (0, 0, 0.5, 1))
    eq(res[('b',)], (0.5, 0, 0.5, 1))
    # subdivide a in two sublevel
    res_bis = _key_splitting(res, ['c', 'd'], [1, 1], ('a',), False, 0)
    assert_(list(res_bis.keys()) == [('a', 'c'), ('a', 'd'), ('b',)])
    eq(res_bis[('a', 'c')], (0.0, 0.0, 0.5, 0.5))
    eq(res_bis[('a', 'd')], (0.0, 0.5, 0.5, 0.5))
    eq(res_bis[('b',)], (0.5, 0, 0.5, 1))
    # starting with a non empty tuple and uneven distribution
    base_rect = {('total',): (0, 0, 1, 1)}
    res = _key_splitting(base_rect, ['a', 'b'], [1, 2], ('total',), True, 0)
    assert_(list(res.keys()) == [('total',) + (e,) for e in ['a', 'b']])
    eq(res[('total', 'a')], (0, 0, 1 / 3, 1))
    eq(res[('total', 'b')], (1 / 3, 0, 2 / 3, 1))


def test_proportion_normalization():
    """Check ``_normalize_split`` on scalars, lists and invalid input."""
    # extremes should give the whole set, as well
    # as if 0 is inserted
    eq(_normalize_split(0.), [0.0, 0.0, 1.0])
    eq(_normalize_split(1.), [0.0, 1.0, 1.0])
    eq(_normalize_split(2.), [0.0, 1.0, 1.0])
    # negative values should raise ValueError
    assert_raises(ValueError, _normalize_split, -1)
    assert_raises(ValueError, _normalize_split, [1., -1])
    assert_raises(ValueError, _normalize_split, [1., -1, 0.])
    # if everything is zero it will complain
    assert_raises(ValueError, _normalize_split, [0.])
    assert_raises(ValueError, _normalize_split, [0., 0.])
    # one-element array should return the whole interval
    eq(_normalize_split([0.5]), [0.0, 1.0])
    eq(_normalize_split([1.]), [0.0, 1.0])
    eq(_normalize_split([2.]), [0.0, 1.0])
    # simple division should give two pieces
    for x in [0.3, 0.5, 0.9]:
        eq(_normalize_split(x), [0., x, 1.0])
    # multiple division should split as the sum of the components
    for x, y in [(0.25, 0.5), (0.1, 0.8), (10., 30.)]:
        eq(_normalize_split([x, y]), [0., x / (x + y), 1.0])
    for x, y, z in [(1., 1., 1.), (0.1, 0.5, 0.7), (10., 30., 40)]:
        total = x + y + z
        eq(_normalize_split([x, y, z]),
           [0., x / total, (x + y) / total, 1.0])


def test_false_split():
    """Check ``_split_rect`` when asked for a single piece."""
    # if you ask it to be divided in only one piece, just return the original
    # one
    pure_square = [0., 0., 1., 1.]
    conf_h = dict(proportion=[1], gap=0.0, horizontal=True)
    conf_v = dict(proportion=[1], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_h), pure_square)
    eq(_split_rect(*pure_square, **conf_v), pure_square)
    conf_h = dict(proportion=[1], gap=0.5, horizontal=True)
    conf_v = dict(proportion=[1], gap=0.5, horizontal=False)
    eq(_split_rect(*pure_square, **conf_h), pure_square)
    eq(_split_rect(*pure_square, **conf_v), pure_square)

    # identity on a void rectangle should not give anything strange
    null_square = [0., 0., 0., 0.]
    conf = dict(proportion=[1], gap=0.0, horizontal=True)
    eq(_split_rect(*null_square, **conf), null_square)
    conf = dict(proportion=[1], gap=1.0, horizontal=True)
    eq(_split_rect(*null_square, **conf), null_square)

    # splitting a negative rectangle should raise error
    neg_square = [0., 0., -1., 0.]
    conf = dict(proportion=[1], gap=0.0, horizontal=True)
    assert_raises(ValueError, _split_rect, *neg_square, **conf)
    conf = dict(proportion=[1, 1], gap=0.0, horizontal=True)
    assert_raises(ValueError, _split_rect, *neg_square, **conf)
    conf = dict(proportion=[1], gap=0.5, horizontal=True)
    assert_raises(ValueError, _split_rect, *neg_square, **conf)
    conf = dict(proportion=[1, 1], gap=0.5, horizontal=True)
    assert_raises(ValueError, _split_rect, *neg_square, **conf)


def test_rect_pure_split():
    """Check gap-free ``_split_rect`` results on the unit square."""
    pure_square = [0., 0., 1., 1.]
    # division in two equal pieces from the perfect square
    h_2split = [(0.0, 0.0, 0.5, 1.0), (0.5, 0.0, 0.5, 1.0)]
    conf_h = dict(proportion=[1, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    v_2split = [(0.0, 0.0, 1.0, 0.5), (0.0, 0.5, 1.0, 0.5)]
    conf_v = dict(proportion=[1, 1], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_v), v_2split)

    # division in two non-equal pieces from the perfect square
    h_2split = [(0.0, 0.0, 1 / 3, 1.0), (1 / 3, 0.0, 2 / 3, 1.0)]
    conf_h = dict(proportion=[1, 2], gap=0.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    v_2split = [(0.0, 0.0, 1.0, 1 / 3), (0.0, 1 / 3, 1.0, 2 / 3)]
    conf_v = dict(proportion=[1, 2], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_v), v_2split)

    # division in three equal pieces from the perfect square
    h_2split = [(0.0, 0.0, 1 / 3, 1.0),
                (1 / 3, 0.0, 1 / 3, 1.0),
                (2 / 3, 0.0, 1 / 3, 1.0)]
    conf_h = dict(proportion=[1, 1, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    v_2split = [(0.0, 0.0, 1.0, 1 / 3),
                (0.0, 1 / 3, 1.0, 1 / 3),
                (0.0, 2 / 3, 1.0, 1 / 3)]
    conf_v = dict(proportion=[1, 1, 1], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_v), v_2split)

    # division in three non-equal pieces from the perfect square
    h_2split = [(0.0, 0.0, 1 / 4, 1.0),
                (1 / 4, 0.0, 1 / 2, 1.0),
                (3 / 4, 0.0, 1 / 4, 1.0)]
    conf_h = dict(proportion=[1, 2, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    v_2split = [(0.0, 0.0, 1.0, 1 / 4),
                (0.0, 1 / 4, 1.0, 1 / 2),
                (0.0, 3 / 4, 1.0, 1 / 4)]
    conf_v = dict(proportion=[1, 2, 1], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_v), v_2split)

    # splitting on a void rectangle should give multiple void
    null_square = [0., 0., 0., 0.]
    conf = dict(proportion=[1, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*null_square, **conf), [null_square, null_square])
    conf = dict(proportion=[1, 2], gap=1.0, horizontal=True)
    eq(_split_rect(*null_square, **conf), [null_square, null_square])


def test_rect_deformed_split():
    """Check gap-free ``_split_rect`` results on an offset rectangle."""
    non_pure_square = [1., -1., 1., 0.5]
    # division in two equal pieces from the deformed rectangle
    h_2split = [(1.0, -1.0, 0.5, 0.5), (1.5, -1.0, 0.5, 0.5)]
    conf_h = dict(proportion=[1, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*non_pure_square, **conf_h), h_2split)

    v_2split = [(1.0, -1.0, 1.0, 0.25), (1.0, -0.75, 1.0, 0.25)]
    conf_v = dict(proportion=[1, 1], gap=0.0, horizontal=False)
    eq(_split_rect(*non_pure_square, **conf_v), v_2split)

    # division in two non-equal pieces from the deformed rectangle
    h_2split = [(1.0, -1.0, 1 / 3, 0.5), (1 + 1 / 3, -1.0, 2 / 3, 0.5)]
    conf_h = dict(proportion=[1, 2], gap=0.0, horizontal=True)
    eq(_split_rect(*non_pure_square, **conf_h), h_2split)

    v_2split = [(1.0, -1.0, 1.0, 1 / 6), (1.0, 1 / 6 - 1, 1.0, 2 / 6)]
    conf_v = dict(proportion=[1, 2], gap=0.0, horizontal=False)
    eq(_split_rect(*non_pure_square, **conf_v), v_2split)


def test_gap_split():
    """Check ``_split_rect`` results on the unit square with ``gap=1.0``."""
    pure_square = [0., 0., 1., 1.]

    # null split
    conf_h = dict(proportion=[1], gap=1.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), pure_square)

    # equal split
    h_2split = [(0.0, 0.0, 0.25, 1.0), (0.75, 0.0, 0.25, 1.0)]
    conf_h = dict(proportion=[1, 1], gap=1.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    # unequal split
    h_2split = [(0.0, 0.0, 1 / 6, 1.0), (0.5 + 1 / 6, 0.0, 1 / 3, 1.0)]
    conf_h = dict(proportion=[1, 2], gap=1.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)


@pytest.mark.matplotlib
def test_default_arg_index(close_figures):
    """A DataFrame passed without an index list must raise ValueError."""
    # 2116
    df = pd.DataFrame({'size': ['small', 'large', 'large', 'small', 'large',
                                'small'],
                       'length': ['long', 'short', 'short', 'long', 'long',
                                  'short']})
    assert_raises(ValueError, mosaic, data=df, title='foobar')


@pytest.mark.matplotlib
def test_missing_category(close_figures):
    """Smoke test: a categorical level absent for one group still renders."""
    # GH5639
    animal = ['dog', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat',
              'dog', 'dog', 'cat']
    size = ['medium', 'large', 'medium', 'medium', 'medium', 'medium',
            'large', 'large', 'large', 'small']
    testdata = pd.DataFrame({'animal': animal, 'size': size})
    testdata['size'] = pd.Categorical(testdata['size'],
                                      categories=['small', 'medium', 'large'])
    testdata = testdata.sort_values('size')
    fig, _ = mosaic(testdata, ['animal', 'size'])
    # rendering to an in-memory PNG forces the figure to actually be drawn
    bio = BytesIO()
    fig.savefig(bio, format='png')