1# -*- coding: utf-8 -*-
2"""Catch all for categorical functions"""
3from __future__ import absolute_import, division, print_function
4
5import pytest
6import numpy as np
7
8from matplotlib.axes import Axes
9import matplotlib.pyplot as plt
10import matplotlib.category as cat
11
def _to_str(value):
    """Return *value* as the text ``StrCategoryFormatter`` would render.

    Thin wrapper over ``cat.StrCategoryFormatter._text`` so tests compare
    against the formatter's own str/bytes normalisation (Python 2/3 text
    handling).
    """
    return cat.StrCategoryFormatter._text(value)
14
15
class TestUnitData(object):
    """Tests for ``cat.UnitData``.

    The assertions expect string categories to be mapped to consecutive
    integer locations in first-seen order, and non-string input to be
    rejected with ``TypeError``.
    """
    # Strings that merely look like numbers/NaN are still plain categories.
    test_cases = [('single', (["hello world"], [0])),
                  ('unicode', (["Здравствуйте мир"], [0])),
                  ('mixed', (['A', "np.nan", 'B', "3.14", "мир"],
                             [0, 1, 2, 3, 4]))]
    ids, data = zip(*test_cases)

    @pytest.mark.parametrize("data, locs", data, ids=ids)
    def test_unit(self, data, locs):
        unit = cat.UnitData(data)
        assert list(unit._mapping.keys()) == data
        assert list(unit._mapping.values()) == locs

    def test_update(self):
        data = ['a', 'd']
        locs = [0, 1]

        # 'd' is already known, so only 'b' and 'e' get new locations.
        data_update = ['b', 'd', 'e']
        unique_data = ['a', 'd', 'b', 'e']
        updated_locs = [0, 1, 2, 3]

        unit = cat.UnitData(data)
        assert list(unit._mapping.keys()) == data
        assert list(unit._mapping.values()) == locs

        unit.update(data_update)
        assert list(unit._mapping.keys()) == unique_data
        assert list(unit._mapping.values()) == updated_locs

    failing_test_cases = [("number", 3.14), ("nan", np.nan),
                          ("list", [3.14, 12]), ("mixed type", ["A", 2])]

    # The failure tests must be parametrized from the non-string inputs
    # above; drawing from ``test_cases`` would pass (data, locs) tuples and
    # let the tests succeed for the wrong reason.
    fids, fdata = zip(*failing_test_cases)

    @pytest.mark.parametrize("fdata", fdata, ids=fids)
    def test_non_string_fails(self, fdata):
        with pytest.raises(TypeError):
            cat.UnitData(fdata)

    @pytest.mark.parametrize("fdata", fdata, ids=fids)
    def test_non_string_update_fails(self, fdata):
        unitdata = cat.UnitData()
        with pytest.raises(TypeError):
            unitdata.update(fdata)
60
61
class FakeAxis(object):
    """Minimal stand-in for a matplotlib axis exposing only ``units``."""

    def __init__(self, units):
        # Kept exactly as passed in; no validation or conversion.
        self.units = units
65
66
class TestStrCategoryConverter(object):
    """Tests for ``cat.StrCategoryConverter``.

    Based on the pandas conversion and factorization tests:

    ref: /pandas/tseries/tests/test_converter.py
         /pandas/tests/test_algos.py:TestFactorize
    """
    test_cases = [("unicode", ["Здравствуйте мир"]),
                  ("ascii", ["hello world"]),
                  ("single", ['a', 'b', 'c']),
                  ("integer string", ["1", "2"]),
                  ("single + values>10", ["A", "B", "C", "D", "E", "F", "G",
                                          "H", "I", "J", "K", "L", "M", "N",
                                          "O", "P", "Q", "R", "S", "T", "U",
                                          "V", "W", "X", "Y", "Z"])]

    ids, values = zip(*test_cases)

    # Inputs mixing strings and numbers must be rejected.
    failing_test_cases = [("mixed", [3.14, 'A', np.inf]),
                          ("string integer", ['42', 42])]

    fids, fvalues = zip(*failing_test_cases)

    @pytest.fixture(autouse=True)
    def mock_axis(self):
        """Give every test a fresh converter, empty unit data and fake axis."""
        self.cc = cat.StrCategoryConverter()
        # self.unit should probably be replaced with a real mock unit
        self.unit = cat.UnitData()
        self.ax = FakeAxis(self.unit)

    @pytest.mark.parametrize("vals", values, ids=ids)
    def test_convert(self, vals):
        # Distinct categories are expected at locations 0..len(vals)-1.
        np.testing.assert_allclose(self.cc.convert(vals, self.ax.units,
                                                   self.ax),
                                   range(len(vals)))

    @pytest.mark.parametrize("value", ["hi", "мир"], ids=["ascii", "unicode"])
    def test_convert_one_string(self, value):
        assert self.cc.convert(value, self.unit, self.ax) == 0

    def test_convert_one_number(self):
        # A bare number is expected to pass through as a float array.
        actual = self.cc.convert(0.0, self.unit, self.ax)
        np.testing.assert_allclose(actual, np.array([0.]))

    def test_convert_float_array(self):
        data = np.array([1, 2, 3], dtype=float)
        actual = self.cc.convert(data, self.unit, self.ax)
        np.testing.assert_allclose(actual, np.array([1., 2., 3.]))

    @pytest.mark.parametrize("fvals", fvalues, ids=fids)
    def test_convert_fail(self, fvals):
        with pytest.raises(TypeError):
            self.cc.convert(fvals, self.unit, self.ax)

    def test_axisinfo(self):
        axis = self.cc.axisinfo(self.unit, self.ax)
        assert isinstance(axis.majloc, cat.StrCategoryLocator)
        assert isinstance(axis.majfmt, cat.StrCategoryFormatter)

    def test_default_units(self):
        assert isinstance(self.cc.default_units(["a"], self.ax), cat.UnitData)
127
128
@pytest.fixture
def ax():
    """Provide an Axes on a new figure and close that figure afterwards."""
    fig = plt.figure()
    yield fig.subplots()
    # Close the figure so the many parametrized tests don't pile up
    # open figures.
    plt.close(fig)
132
133
# (test id, plotting method) pairs shared by the parametrized plotting tests;
# deriving both lists from one table keeps ids and methods aligned.
_PLOTTERS = [("scatter", Axes.scatter),
             ("plot", Axes.plot),
             ("bar", Axes.bar)]
PLOT_IDS, PLOT_LIST = map(list, zip(*_PLOTTERS))
136
137
class TestStrCategoryLocator(object):
    """Tests for ``cat.StrCategoryLocator``."""

    def test_StrCategoryLocator(self):
        locs = list(range(11))
        unit = cat.UnitData([str(j) for j in locs])
        ticks = cat.StrCategoryLocator(unit._mapping)
        np.testing.assert_array_equal(ticks.tick_values(None, None), locs)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_StrCategoryLocatorPlot(self, ax, plotter):
        # Use the parametrized plotting method (previously this always called
        # Axes.plot, so scatter and bar were never exercised).
        plotter(ax, [1, 2, 3], ["a", "b", "c"])
        np.testing.assert_array_equal(ax.yaxis.major.locator(), range(3))
149
150
class TestStrCategoryFormatter(object):
    """Tests for ``cat.StrCategoryFormatter``, standalone and via plotting."""

    test_cases = [("ascii", ["hello", "world", "hi"]),
                  ("unicode", ["Здравствуйте", "привет"])]

    ids, cases = zip(*test_cases)

    @pytest.mark.parametrize("ydata", cases, ids=ids)
    def test_StrCategoryFormatter(self, ydata):
        # No Axes needed: the formatter works directly on the unit mapping.
        unit = cat.UnitData(ydata)
        labels = cat.StrCategoryFormatter(unit._mapping)
        for i, d in enumerate(ydata):
            assert labels(i, i) == _to_str(d)

    @pytest.mark.parametrize("ydata", cases, ids=ids)
    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_StrCategoryFormatterPlot(self, ax, ydata, plotter):
        plotter(ax, range(len(ydata)), ydata)
        formatter = ax.yaxis.major.formatter
        for i, d in enumerate(ydata):
            assert formatter(i, i) == _to_str(d)
        # Expected: no label one past the last category, nor for pos=None.
        past_end = len(ydata)
        assert formatter(past_end, past_end) == ""
        assert formatter(0, None) == ""
172
173
def axis_test(axis, labels):
    """Assert that *axis* holds one categorical tick per entry of *labels*.

    Checks that major tick locations are ``0 .. len(labels) - 1``, that the
    major formatter renders each location as the matching label, and that
    the axis' category mapping lists *labels* in order at those locations.
    """
    ticks = list(range(len(labels)))
    np.testing.assert_array_equal(axis.get_majorticklocs(), ticks)
    graph_labels = [axis.major.formatter(i, i) for i in ticks]
    assert graph_labels == [_to_str(label) for label in labels]
    assert list(axis.units._mapping.keys()) == list(labels)
    assert list(axis.units._mapping.values()) == ticks
181
182
class TestPlotBytes(object):
    """Categorical x data given as str list, bytes list or bytes ndarray."""

    bytes_cases = [('string list', ['a', 'b', 'c']),
                   ('bytes list', [b'a', b'b', b'c']),
                   ('bytes ndarray', np.array([b'a', b'b', b'c']))]

    bytes_ids = tuple(case_id for case_id, _ in bytes_cases)
    bytes_data = tuple(case_data for _, case_data in bytes_cases)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    @pytest.mark.parametrize("bdata", bytes_data, ids=bytes_ids)
    def test_plot_bytes(self, ax, plotter, bdata):
        # One numeric value per category on the other axis.
        plotter(ax, bdata, np.array([4, 6, 5]))
        axis_test(ax.xaxis, bdata)
196
197
class TestPlotNumlike(object):
    """Number-like strings/bytes should be plotted as categories."""

    numlike_cases = [('string list', ['1', '11', '3']),
                     ('string ndarray', np.array(['1', '11', '3'])),
                     ('bytes list', [b'1', b'11', b'3']),
                     ('bytes ndarray', np.array([b'1', b'11', b'3']))]
    numlike_ids = tuple(case_id for case_id, _ in numlike_cases)
    numlike_data = tuple(case_data for _, case_data in numlike_cases)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    @pytest.mark.parametrize("ndata", numlike_data, ids=numlike_ids)
    def test_plot_numlike(self, ax, plotter, ndata):
        # Categories keep insertion order ('1', '11', '3'), not numeric order.
        plotter(ax, ndata, np.array([4, 6, 5]))
        axis_test(ax.xaxis, ndata)
211
212
class TestPlotTypes(object):
    """Categorical plotting through scatter, plot and bar on x and/or y."""

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_unicode(self, ax, plotter):
        words = ['Здравствуйте', 'привет']
        plotter(ax, words, [0, 1])
        axis_test(ax.xaxis, words)

    @pytest.fixture
    def test_data(self):
        """Store shared categorical and numeric data on the test instance."""
        self.x = ["hello", "happy", "world"]
        self.xy = [2, 6, 3]
        self.y = ["Python", "is", "fun"]
        self.yx = [3, 4, 5]

    # The fixture is requested via ``usefixtures`` only; its values are read
    # from ``self``, so it does not need to be a test argument as well.
    @pytest.mark.usefixtures("test_data")
    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_xaxis(self, ax, plotter):
        plotter(ax, self.x, self.xy)
        axis_test(ax.xaxis, self.x)

    @pytest.mark.usefixtures("test_data")
    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_yaxis(self, ax, plotter):
        plotter(ax, self.yx, self.y)
        axis_test(ax.yaxis, self.y)

    @pytest.mark.usefixtures("test_data")
    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_xyaxis(self, ax, plotter):
        plotter(ax, self.x, self.y)
        axis_test(ax.xaxis, self.x)
        axis_test(ax.yaxis, self.y)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_update_plot(self, ax, plotter):
        # Repeated calls extend the categories in first-seen order.
        plotter(ax, ['a', 'b'], ['e', 'g'])
        plotter(ax, ['a', 'b', 'd'], ['f', 'a', 'b'])
        plotter(ax, ['b', 'c', 'd'], ['g', 'e', 'd'])
        axis_test(ax.xaxis, ['a', 'b', 'd', 'c'])
        axis_test(ax.yaxis, ['e', 'g', 'f', 'a', 'b', 'd'])

    failing_test_cases = [("mixed", ['A', 3.14]),
                          ("number integer", ['1', 1]),
                          ("string integer", ['42', 42]),
                          ("missing", ['12', np.nan])]

    fids, fvalues = zip(*failing_test_cases)

    # Only scatter is expected to raise TypeError here; plot and bar are
    # marked xfail.
    PLOT_BROKEN_LIST = [Axes.scatter,
                        pytest.param(Axes.plot, marks=pytest.mark.xfail),
                        pytest.param(Axes.bar, marks=pytest.mark.xfail)]

    PLOT_BROKEN_IDS = ["scatter", "plot", "bar"]

    @pytest.mark.parametrize("plotter", PLOT_BROKEN_LIST, ids=PLOT_BROKEN_IDS)
    @pytest.mark.parametrize("xdata", fvalues, ids=fids)
    def test_mixed_type_exception(self, ax, plotter, xdata):
        with pytest.raises(TypeError):
            plotter(ax, xdata, [1, 2])

    @pytest.mark.parametrize("plotter", PLOT_BROKEN_LIST, ids=PLOT_BROKEN_IDS)
    @pytest.mark.parametrize("xdata", fvalues, ids=fids)
    def test_mixed_type_update_exception(self, ax, plotter, xdata):
        # Set up with valid data outside the raises block, so that only the
        # mixed-type update can satisfy the TypeError expectation.
        plotter(ax, [0, 3], [1, 3])
        with pytest.raises(TypeError):
            plotter(ax, xdata, [1, 2])
279