1from statsmodels.compat.python import lrange
2
3from io import BytesIO
4from itertools import product
5
6import numpy as np
7from numpy.testing import assert_, assert_raises
8import pandas as pd
9import pytest
10
11from statsmodels.api import datasets
12
13# utilities for the tests
14
15try:
16    import matplotlib.pyplot as plt  # noqa:F401
17except ImportError:
18    pass
19
# the main drawing function (mosaic) and the private helpers
# that are tested for accuracy
22from statsmodels.graphics.mosaicplot import (
23    _hierarchical_split,
24    _key_splitting,
25    _normalize_split,
26    _reduce_dict,
27    _split_rect,
28    mosaic,
29)
30
31
@pytest.mark.matplotlib
def test_data_conversion(close_figures):
    """Smoke-test ``mosaic`` on every kind of input container used here.

    A single 4x4 grid of axes is filled with plots built from a dict, a
    ``pandas.Series``, a list and a ``numpy`` array (one and two levels of
    categories), and from a ``pandas.DataFrame`` indexed by column names.
    Nothing is asserted about the drawings: the test passes when no call
    raises.
    """
    # Keep the figure handle: the title must be set on this figure.  Calling
    # ``plt.suptitle`` after ``plt.close("all")`` would put the title on a
    # newly created, empty figure instead of on the grid drawn below.
    fig, ax = plt.subplots(4, 4)

    # row 0: a single level of categories
    data = {'ax': 1, 'bx': 2, 'cx': 3}
    mosaic(data, ax=ax[0, 0], title='basic dict', axes_label=False)
    data = pd.Series(data)
    mosaic(data, ax=ax[0, 1], title='basic series', axes_label=False)
    data = [1, 2, 3]
    mosaic(data, ax=ax[0, 2], title='basic list', axes_label=False)
    data = np.asarray(data)
    mosaic(data, ax=ax[0, 3], title='basic array', axes_label=False)

    # rows 1 and 2: two levels of categories, in natural order (row 1) and
    # with the levels swapped through ``index=[1, 0]`` (row 2)
    data = {('ax', 'cx'): 1, ('bx', 'cx'): 2, ('ax', 'dx'): 3, ('bx', 'dx'): 4}
    mosaic(data, ax=ax[1, 0], title='compound dict', axes_label=False)
    mosaic(data, ax=ax[2, 0], title='inverted keys dict', index=[1, 0], axes_label=False)
    data = pd.Series(data)
    mosaic(data, ax=ax[1, 1], title='compound series', axes_label=False)
    mosaic(data, ax=ax[2, 1], title='inverted keys series', index=[1, 0])
    data = [[1, 2], [3, 4]]
    mosaic(data, ax=ax[1, 2], title='compound list', axes_label=False)
    mosaic(data, ax=ax[2, 2], title='inverted keys list', index=[1, 0])
    data = np.array([[1, 2], [3, 4]])
    mosaic(data, ax=ax[1, 3], title='compound array', axes_label=False)
    mosaic(data, ax=ax[2, 3], title='inverted keys array', index=[1, 0], axes_label=False)

    # row 3: a DataFrame, selecting the columns to count by name
    gender = ['male', 'male', 'male', 'female', 'female', 'female']
    pet = ['cat', 'dog', 'dog', 'cat', 'dog', 'cat']
    data = pd.DataFrame({'gender': gender, 'pet': pet})
    mosaic(data, ['gender'], ax=ax[3, 0], title='dataframe by key 1', axes_label=False)
    mosaic(data, ['pet'], ax=ax[3, 1], title='dataframe by key 2', axes_label=False)
    mosaic(data, ['gender', 'pet'], ax=ax[3, 2], title='both keys', axes_label=False)
    mosaic(data, ['pet', 'gender'], ax=ax[3, 3], title='keys inverted', axes_label=False)

    fig.suptitle('testing data conversion (plot 1 of 4)')
    plt.close("all")
73
74
@pytest.mark.matplotlib
def test_mosaic_simple(close_figures):
    """Draw a four-level mosaic of synthetic counts with per-tile styling.

    The counts are 1, 2, 3, ... assigned to the categories in the order
    produced by ``itertools.product``.  Smoke test: nothing is asserted.
    """
    # the levels of each of the four categorical variables
    levels = (['male', 'female'], ['old', 'adult', 'young'],
              ['worker', 'unemployed'], ['healty', 'ill'])
    # the cartesian product of the levels is the complete set of categories
    categories = list(product(*levels))
    counts = {cat: n for n, cat in enumerate(categories, start=1)}

    # base colours: males in blue, females in red
    styles = {('male',): {'color': 'b'}, ('female',): {'color': 'r'}}
    # every 'ill' tile gets its own hatched colour, chosen by gender
    for cat in categories:
        if 'ill' not in cat:
            continue
        shade = 'BlueViolet' if 'male' in cat else 'Crimson'
        styles[cat] = {'color': shade, 'hatch': '+'}

    # mosaic of the data, with the given gap and tile properties
    mosaic(counts, gap=0.05, properties=styles, axes_label=False)
    plt.suptitle('syntetic data, 4 categories (plot 2 of 4)')
102
103
@pytest.mark.matplotlib
def test_mosaic(close_figures):
    """Draw four mosaics from the ``fair`` dataset shipped with statsmodels.

    The numeric codes of the columns that are plotted are first replaced by
    readable labels.  Smoke test: nothing is asserted about the drawings.
    """
    def no_label(key):
        # leave every tile of the three-level plot unlabelled
        return ''

    # load the data and clean it a bit
    fair = datasets.fair.load_pandas()
    frame = fair.exog
    # any time greater than 0 is cheating
    frame['cheated'] = fair.endog > 0
    frame = frame.sort_values(['rate_marriage', 'religious'])

    # numeric codes -> descriptive labels, one mapping per column
    labels = {
        'rate_marriage': {1: 'awful', 2: 'bad', 3: 'intermediate',
                          4: 'good', 5: 'wonderful'},
        'religious': {1: 'non religious', 2: 'poorly religious',
                      3: 'religious', 4: 'very religious'},
        'cheated': {False: 'faithful', True: 'cheated'},
    }
    for column, mapping in labels.items():
        frame[column] = frame[column].map(mapping)
    # finished cleaning

    _, grid = plt.subplots(2, 2)
    both = grid[1, 0]
    mosaic(frame, ['rate_marriage', 'cheated'], ax=grid[0, 0],
           title='by marriage happiness')
    mosaic(frame, ['religious', 'cheated'], ax=grid[0, 1],
           title='by religiosity')
    mosaic(frame, ['rate_marriage', 'religious', 'cheated'], ax=both,
           title='by both', labelizer=no_label)
    both.set_xlabel('marriage rating')
    both.set_ylabel('religion status')
    mosaic(frame, ['religious', 'rate_marriage'], ax=grid[1, 1],
           title='inter-dependence', axes_label=False)
    plt.suptitle("extramarital affairs (plot 3 of 4)")
139
140
@pytest.mark.matplotlib
def test_mosaic_very_complex(close_figures):
    """Draw a scatter-matrix of mosaics, one per pair of variables.

    The diagonal cells only show the variable name; every off-diagonal cell
    shows the mosaic of the counts of the synthetic dataset aggregated over
    the two variables of that cell.  Could be easily converted into a new
    function that does this automatically based on the type of data.
    Smoke test: nothing is asserted about the drawings.
    """
    key_name = ['gender', 'age', 'health', 'work']
    key_base = (['male', 'female'], ['old', 'young'],
                ['healty', 'ill'], ['work', 'unemployed'])
    keys = list(product(*key_base))
    data = dict(zip(keys, range(1, 1 + len(keys))))
    props = {
        ('male', 'old'): {'color': 'r'},
        ('female',): {'color': 'pink'},
    }
    n_vars = len(key_base)
    _, axes = plt.subplots(n_vars, n_vars)
    for i in range(n_vars):
        for j in range(n_vars):
            if i == j:
                # diagonal: only the name of the variable, without ticks
                diag = axes[i, i]
                diag.text(0.5, 0.5, key_name[i], ha='center', va='center')
                diag.set_xticks([])
                diag.set_xticklabels([])
                diag.set_yticks([])
                diag.set_yticklabels([])
                continue

            lo, hi = min(i, j), max(i, j)
            others = sorted(set(range(n_vars)) - {i, j})
            # move the two variables of this cell to the front of each key
            reordered = {
                (k[lo], k[hi]) + tuple(k[r] for r in others): v
                for k, v in data.items()
            }
            # Marginal count of every (lo, hi) pair, in order of first
            # appearance.  The sums are taken on the untouched ``reordered``
            # dict: storing a partial sum back into the dict being summed
            # would make ``_reduce_dict`` count it again on the next key
            # sharing the same pair.
            pairs = dict.fromkeys(k[:2] for k in reordered)
            pair_counts = {
                pair: _reduce_dict(reordered, pair) for pair in pairs
            }
            mosaic(pair_counts, ax=axes[i, j], axes_label=False,
                   properties=props, gap=0.05, horizontal=i > j)
    plt.suptitle('old males should look bright red,  (plot 4 of 4)')
180
181
@pytest.mark.matplotlib
def test_axes_labeling(close_figures):
    """Draw the same random counts horizontally and vertically with labels.

    The left plot uses a single rotation for the axis labels, the right one
    a list of four rotations.  Smoke test: nothing is asserted.
    """
    from numpy.random import rand

    def initials(key):
        # label a tile with the first letter of each of its levels
        return ''.join(level[0] for level in key)

    levels = (['male', 'female'], ['old', 'adult', 'young'],
              ['worker', 'unemployed'], ['yes', 'no'])
    # the cartesian product of the levels is the complete set of categories
    categories = list(product(*levels))
    counts = dict(zip(categories, rand(len(categories))))

    fig, (left, right) = plt.subplots(1, 2, figsize=(16, 8))
    mosaic(counts, ax=left, labelizer=initials, horizontal=True,
           label_rotation=45)
    mosaic(counts, ax=right, labelizer=initials, horizontal=False,
           label_rotation=[0, 45, 90, 0])
    fig.suptitle("correct alignment of the axes labels")
198
199
@pytest.mark.smoke
@pytest.mark.matplotlib
def test_mosaic_empty_cells(close_figures):
    """Plot data in which most id1/id2 combinations never occur (GH#2286).

    The same observations are plotted twice: once from the unstacked
    cross-tabulation and once directly from the DataFrame.  Smoke test.
    """
    # (row label, id2, id1), in the order the rows are inserted
    rows = [
        (64, 'Angelica', 'TGP'),
        (65, 'DXW_UID', 'Retention01'),
        (66, 'casuid01', 'default'),
        (67, 'casuid01', 'default'),
        (68, 'EC93_uid', 'Musa_EC_9_3'),
        (69, 'EC93_uid', 'Musa_EC_9_3'),
        (70, 'EC93_uid', 'Musa_EC_9_3'),
        (60, 'DXW_UID', 'default'),
        (61, 'AtmosFox', 'default'),
        (62, 'DXW_UID', 'default'),
        (63, 'DXW_UID', 'default'),
    ]
    frame = pd.DataFrame({
        'id2': {label: second for label, second, _ in rows},
        'id1': {label: first for label, _, first in rows},
    })

    table = pd.crosstab(frame.id1, frame.id2)
    mosaic(table.T.unstack())
    mosaic(frame, ['id1', 'id2'])
221
222
def eq(x, y):
    """Assert that ``x`` and ``y`` are element-wise close.

    Closeness is decided by ``np.allclose`` with its default tolerances.

    Parameters
    ----------
    x, y : array_like
        The values to compare.

    Raises
    ------
    AssertionError
        If the values are not close; the message shows both operands.
    """
    assert_(np.allclose(x, y),
            msg="{!r} is not close to {!r}".format(x, y))
224
225
def test_recursive_split():
    """Check the tiles returned by ``_hierarchical_split`` without gaps.

    Equal counts are split over one level and then over two levels; the
    order of the returned keys and every returned rectangle are verified.
    The rectangles are compared with ``eq``: assigning the expected values
    into the result would verify nothing.
    """
    # one level: two equal tiles side by side
    keys = list(product('mf'))
    data = dict(zip(keys, [1] * len(keys)))
    res = _hierarchical_split(data, gap=0)
    assert_(list(res.keys()) == keys)
    eq(res[('m',)], (0.0, 0.0, 0.5, 1.0))
    eq(res[('f',)], (0.5, 0.0, 0.5, 1.0))

    # two levels: each of the two tiles is cut again in three equal parts
    keys = list(product('mf', 'yao'))
    data = dict(zip(keys, [1] * len(keys)))
    res = _hierarchical_split(data, gap=0)
    assert_(list(res.keys()) == keys)
    eq(res[('m', 'y')], (0.0, 0.0, 0.5, 1 / 3))
    eq(res[('m', 'a')], (0.0, 1 / 3, 0.5, 1 / 3))
    eq(res[('m', 'o')], (0.0, 2 / 3, 0.5, 1 / 3))
    eq(res[('f', 'y')], (0.5, 0.0, 0.5, 1 / 3))
    eq(res[('f', 'a')], (0.5, 1 / 3, 0.5, 1 / 3))
    eq(res[('f', 'o')], (0.5, 2 / 3, 0.5, 1 / 3))
243
244
def test__reduce_dict():
    """Check the partial sums of ``_reduce_dict`` for growing key prefixes."""
    categories = list(product('mf', 'oy', 'wn'))
    prefixes = (('m',), ('m', 'o'), ('m', 'o', 'w'))

    # every category counted once
    uniform = dict(zip(categories, [1] * 8))
    for prefix, total in zip(prefixes, (4, 2, 1)):
        eq(_reduce_dict(uniform, prefix), total)

    # counts 0, 1, ..., 7 in the order of the categories
    ramp = dict(zip(categories, lrange(8)))
    for prefix, total in zip(prefixes, (6, 1, 0)):
        eq(_reduce_dict(ramp, prefix), total)
254
255
def test__key_splitting():
    """Check keys and rectangles produced by ``_key_splitting``."""
    # subdivide the rectangle stored under the empty tuple
    root = {(): (0, 0, 1, 1)}
    first = _key_splitting(root, ['a', 'b'], [1, 1], (), True, 0)
    assert_(list(first.keys()) == [('a',), ('b',)])
    eq(first[('a',)], (0, 0, 0.5, 1))
    eq(first[('b',)], (0.5, 0, 0.5, 1))

    # subdivide 'a' in two sublevels, 'b' must be left untouched
    second = _key_splitting(first, ['c', 'd'], [1, 1], ('a',), False, 0)
    expected = {
        ('a', 'c'): (0.0, 0.0, 0.5, 0.5),
        ('a', 'd'): (0.0, 0.5, 0.5, 0.5),
        ('b',): (0.5, 0, 0.5, 1),
    }
    assert_(list(second.keys()) == list(expected))
    for key, rect in expected.items():
        eq(second[key], rect)

    # start from a non-empty key and use an uneven distribution
    root = {('total',): (0, 0, 1, 1)}
    uneven = _key_splitting(root, ['a', 'b'], [1, 2], ('total',), True, 0)
    assert_(list(uneven.keys()) == [('total', 'a'), ('total', 'b')])
    eq(uneven[('total', 'a')], (0, 0, 1 / 3, 1))
    eq(uneven[('total', 'b')], (1 / 3, 0, 2 / 3, 1))
275
276
def test_proportion_normalization():
    """Check ``_normalize_split`` on scalars, sequences and invalid input."""
    # scalars at or beyond the extremes give the whole interval
    scalar_cases = (
        (0., [0.0, 0.0, 1.0]),
        (1., [0.0, 1.0, 1.0]),
        (2., [0.0, 1.0, 1.0]),
    )
    for proportion, expected in scalar_cases:
        eq(_normalize_split(proportion), expected)

    # negative values, or proportions that are all zero, raise ValueError
    for invalid in (-1, [1., -1], [1., -1, 0.], [0.], [0., 0.]):
        assert_raises(ValueError, _normalize_split, invalid)

    # a one-element sequence returns the whole interval
    for single in (0.5, 1., 2.):
        eq(_normalize_split([single]), [0.0, 1.0])

    # a scalar inside the interval gives two pieces
    for x in [0.3, 0.5, 0.9]:
        eq(_normalize_split(x), [0., x, 1.0])

    # several proportions split at the cumulative share of the components
    for x, y in [(0.25, 0.5), (0.1, 0.8), (10., 30.)]:
        total = x + y
        eq(_normalize_split([x, y]), [0., x / total, 1.0])
    for x, y, z in [(1., 1., 1.), (0.1, 0.5, 0.7), (10., 30., 40)]:
        total = x + y + z
        eq(_normalize_split([x, y, z]),
           [0., x / total, (x + y) / total, 1.0])
303
304
def test_false_split():
    """Check ``_split_rect`` when asked for a single piece.

    A one-piece split must return the original rectangle whatever the gap
    and the direction; a rectangle with negative width must be rejected.
    """
    # one piece only: the original rectangle is returned
    unit = [0., 0., 1., 1.]
    for gap in (0.0, 0.5):
        for horizontal in (True, False):
            tiles = _split_rect(*unit, proportion=[1], gap=gap,
                                horizontal=horizontal)
            eq(tiles, unit)

    # identity on a void rectangle should not give anything strange
    void = [0., 0., 0., 0.]
    for gap in (0.0, 1.0):
        eq(_split_rect(*void, proportion=[1], gap=gap, horizontal=True), void)

    # splitting a negative rectangle should raise an error
    negative = [0., 0., -1., 0.]
    for gap in (0.0, 0.5):
        for proportion in ([1], [1, 1]):
            assert_raises(ValueError, _split_rect, *negative,
                          proportion=proportion, gap=gap, horizontal=True)
335
336
def test_rect_pure_split():
    """Check gap-free splits of the unit square in both directions."""
    unit = [0., 0., 1., 1.]
    # (proportion, tiles of the horizontal split, tiles of the vertical one)
    cases = [
        # two equal pieces
        ([1, 1],
         [(0.0, 0.0, 0.5, 1.0), (0.5, 0.0, 0.5, 1.0)],
         [(0.0, 0.0, 1.0, 0.5), (0.0, 0.5, 1.0, 0.5)]),
        # two non-equal pieces
        ([1, 2],
         [(0.0, 0.0, 1 / 3, 1.0), (1 / 3, 0.0, 2 / 3, 1.0)],
         [(0.0, 0.0, 1.0, 1 / 3), (0.0, 1 / 3, 1.0, 2 / 3)]),
        # three equal pieces
        ([1, 1, 1],
         [(0.0, 0.0, 1 / 3, 1.0), (1 / 3, 0.0, 1 / 3, 1.0),
          (2 / 3, 0.0, 1 / 3, 1.0)],
         [(0.0, 0.0, 1.0, 1 / 3), (0.0, 1 / 3, 1.0, 1 / 3),
          (0.0, 2 / 3, 1.0, 1 / 3)]),
        # three non-equal pieces
        ([1, 2, 1],
         [(0.0, 0.0, 1 / 4, 1.0), (1 / 4, 0.0, 1 / 2, 1.0),
          (3 / 4, 0.0, 1 / 4, 1.0)],
         [(0.0, 0.0, 1.0, 1 / 4), (0.0, 1 / 4, 1.0, 1 / 2),
          (0.0, 3 / 4, 1.0, 1 / 4)]),
    ]
    for proportion, across, down in cases:
        eq(_split_rect(*unit, proportion=proportion, gap=0.0,
                       horizontal=True), across)
        eq(_split_rect(*unit, proportion=proportion, gap=0.0,
                       horizontal=False), down)

    # splitting a void rectangle should give several void rectangles
    void = [0., 0., 0., 0.]
    for proportion, gap in (([1, 1], 0.0), ([1, 2], 1.0)):
        eq(_split_rect(*void, proportion=proportion, gap=gap,
                       horizontal=True), [void, void])
385
386
def test_rect_deformed_split():
    """Check gap-free splits of a rectangle that is not the unit square."""
    rect = [1., -1., 1., 0.5]
    # (proportion, horizontal flag, expected tiles)
    cases = [
        # two equal pieces
        ([1, 1], True, [(1.0, -1.0, 0.5, 0.5), (1.5, -1.0, 0.5, 0.5)]),
        ([1, 1], False, [(1.0, -1.0, 1.0, 0.25), (1.0, -0.75, 1.0, 0.25)]),
        # two non-equal pieces
        ([1, 2], True,
         [(1.0, -1.0, 1 / 3, 0.5), (1 + 1 / 3, -1.0, 2 / 3, 0.5)]),
        ([1, 2], False,
         [(1.0, -1.0, 1.0, 1 / 6), (1.0, 1 / 6 - 1, 1.0, 2 / 6)]),
    ]
    for proportion, horizontal, expected in cases:
        tiles = _split_rect(*rect, proportion=proportion, gap=0.0,
                            horizontal=horizontal)
        eq(tiles, expected)
406
407
def test_gap_split():
    """Check horizontal splits of the unit square with ``gap=1.0``."""
    unit = [0., 0., 1., 1.]
    # (proportion, expected result)
    cases = [
        # null split: a single piece is the square itself
        ([1], unit),
        # equal split
        ([1, 1], [(0.0, 0.0, 0.25, 1.0), (0.75, 0.0, 0.25, 1.0)]),
        # unequal split
        ([1, 2], [(0.0, 0.0, 1 / 6, 1.0), (0.5 + 1 / 6, 0.0, 1 / 3, 1.0)]),
    ]
    for proportion, expected in cases:
        tiles = _split_rect(*unit, proportion=proportion, gap=1.0,
                            horizontal=True)
        eq(tiles, expected)
424
425
@pytest.mark.matplotlib
def test_default_arg_index(close_figures):
    """A DataFrame passed to ``mosaic`` without ``index`` raises ValueError.

    Reference in the original comment: 2116 (presumably an issue number).
    """
    sizes = ['small', 'large', 'large', 'small', 'large', 'small']
    lengths = ['long', 'short', 'short', 'long', 'long', 'short']
    frame = pd.DataFrame({'size': sizes, 'length': lengths})
    with pytest.raises(ValueError):
        mosaic(data=frame, title='foobar')
434
435
@pytest.mark.matplotlib
def test_missing_category(close_figures):
    """Plot and render data where one category combination is absent (GH5639).

    ``size`` is an ordered list of categories and no 'dog' is 'small'.  The
    figure is saved to an in-memory PNG; the test passes if nothing raises.
    """
    animals = ['dog', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat',
               'dog', 'dog', 'cat']
    sizes = ['medium', 'large', 'medium', 'medium', 'medium', 'medium',
             'large', 'large', 'large', 'small']
    frame = pd.DataFrame({'animal': animals, 'size': sizes})
    size_levels = ['small', 'medium', 'large']
    frame['size'] = pd.Categorical(frame['size'], categories=size_levels)
    frame = frame.sort_values('size')

    figure, _ = mosaic(frame, ['animal', 'size'])
    figure.savefig(BytesIO(), format='png')
450