1# -*- coding: utf-8 -*- 2"""Catch all for categorical functions""" 3from __future__ import absolute_import, division, print_function 4 5import pytest 6import numpy as np 7 8from matplotlib.axes import Axes 9import matplotlib.pyplot as plt 10import matplotlib.category as cat 11 12# Python2/3 text handling 13_to_str = cat.StrCategoryFormatter._text 14 15 16class TestUnitData(object): 17 test_cases = [('single', (["hello world"], [0])), 18 ('unicode', (["Здравствуйте мир"], [0])), 19 ('mixed', (['A', "np.nan", 'B', "3.14", "мир"], 20 [0, 1, 2, 3, 4]))] 21 ids, data = zip(*test_cases) 22 23 @pytest.mark.parametrize("data, locs", data, ids=ids) 24 def test_unit(self, data, locs): 25 unit = cat.UnitData(data) 26 assert list(unit._mapping.keys()) == data 27 assert list(unit._mapping.values()) == locs 28 29 def test_update(self): 30 data = ['a', 'd'] 31 locs = [0, 1] 32 33 data_update = ['b', 'd', 'e'] 34 unique_data = ['a', 'd', 'b', 'e'] 35 updated_locs = [0, 1, 2, 3] 36 37 unit = cat.UnitData(data) 38 assert list(unit._mapping.keys()) == data 39 assert list(unit._mapping.values()) == locs 40 41 unit.update(data_update) 42 assert list(unit._mapping.keys()) == unique_data 43 assert list(unit._mapping.values()) == updated_locs 44 45 failing_test_cases = [("number", 3.14), ("nan", np.nan), 46 ("list", [3.14, 12]), ("mixed type", ["A", 2])] 47 48 fids, fdata = zip(*test_cases) 49 50 @pytest.mark.parametrize("fdata", fdata, ids=fids) 51 def test_non_string_fails(self, fdata): 52 with pytest.raises(TypeError): 53 cat.UnitData(fdata) 54 55 @pytest.mark.parametrize("fdata", fdata, ids=fids) 56 def test_non_string_update_fails(self, fdata): 57 unitdata = cat.UnitData() 58 with pytest.raises(TypeError): 59 unitdata.update(fdata) 60 61 62class FakeAxis(object): 63 def __init__(self, units): 64 self.units = units 65 66 67class TestStrCategoryConverter(object): 68 """Based on the pandas conversion and factorization tests: 69 70 ref: /pandas/tseries/tests/test_converter.py 71 
/pandas/tests/test_algos.py:TestFactorize 72 """ 73 test_cases = [("unicode", ["Здравствуйте мир"]), 74 ("ascii", ["hello world"]), 75 ("single", ['a', 'b', 'c']), 76 ("integer string", ["1", "2"]), 77 ("single + values>10", ["A", "B", "C", "D", "E", "F", "G", 78 "H", "I", "J", "K", "L", "M", "N", 79 "O", "P", "Q", "R", "S", "T", "U", 80 "V", "W", "X", "Y", "Z"])] 81 82 ids, values = zip(*test_cases) 83 84 failing_test_cases = [("mixed", [3.14, 'A', np.inf]), 85 ("string integer", ['42', 42])] 86 87 fids, fvalues = zip(*failing_test_cases) 88 89 @pytest.fixture(autouse=True) 90 def mock_axis(self, request): 91 self.cc = cat.StrCategoryConverter() 92 # self.unit should be probably be replaced with real mock unit 93 self.unit = cat.UnitData() 94 self.ax = FakeAxis(self.unit) 95 96 @pytest.mark.parametrize("vals", values, ids=ids) 97 def test_convert(self, vals): 98 np.testing.assert_allclose(self.cc.convert(vals, self.ax.units, 99 self.ax), 100 range(len(vals))) 101 102 @pytest.mark.parametrize("value", ["hi", "мир"], ids=["ascii", "unicode"]) 103 def test_convert_one_string(self, value): 104 assert self.cc.convert(value, self.unit, self.ax) == 0 105 106 def test_convert_one_number(self): 107 actual = self.cc.convert(0.0, self.unit, self.ax) 108 np.testing.assert_allclose(actual, np.array([0.])) 109 110 def test_convert_float_array(self): 111 data = np.array([1, 2, 3], dtype=float) 112 actual = self.cc.convert(data, self.unit, self.ax) 113 np.testing.assert_allclose(actual, np.array([1., 2., 3.])) 114 115 @pytest.mark.parametrize("fvals", fvalues, ids=fids) 116 def test_convert_fail(self, fvals): 117 with pytest.raises(TypeError): 118 self.cc.convert(fvals, self.unit, self.ax) 119 120 def test_axisinfo(self): 121 axis = self.cc.axisinfo(self.unit, self.ax) 122 assert isinstance(axis.majloc, cat.StrCategoryLocator) 123 assert isinstance(axis.majfmt, cat.StrCategoryFormatter) 124 125 def test_default_units(self): 126 assert isinstance(self.cc.default_units(["a"], 
self.ax), cat.UnitData) 127 128 129@pytest.fixture 130def ax(): 131 return plt.figure().subplots() 132 133 134PLOT_LIST = [Axes.scatter, Axes.plot, Axes.bar] 135PLOT_IDS = ["scatter", "plot", "bar"] 136 137 138class TestStrCategoryLocator(object): 139 def test_StrCategoryLocator(self): 140 locs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 141 unit = cat.UnitData([str(j) for j in locs]) 142 ticks = cat.StrCategoryLocator(unit._mapping) 143 np.testing.assert_array_equal(ticks.tick_values(None, None), locs) 144 145 @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS) 146 def test_StrCategoryLocatorPlot(self, ax, plotter): 147 ax.plot(["a", "b", "c"]) 148 np.testing.assert_array_equal(ax.yaxis.major.locator(), range(3)) 149 150 151class TestStrCategoryFormatter(object): 152 test_cases = [("ascii", ["hello", "world", "hi"]), 153 ("unicode", ["Здравствуйте", "привет"])] 154 155 ids, cases = zip(*test_cases) 156 157 @pytest.mark.parametrize("ydata", cases, ids=ids) 158 def test_StrCategoryFormatter(self, ax, ydata): 159 unit = cat.UnitData(ydata) 160 labels = cat.StrCategoryFormatter(unit._mapping) 161 for i, d in enumerate(ydata): 162 assert labels(i, i) == _to_str(d) 163 164 @pytest.mark.parametrize("ydata", cases, ids=ids) 165 @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS) 166 def test_StrCategoryFormatterPlot(self, ax, ydata, plotter): 167 plotter(ax, range(len(ydata)), ydata) 168 for i, d in enumerate(ydata): 169 assert ax.yaxis.major.formatter(i, i) == _to_str(d) 170 assert ax.yaxis.major.formatter(i+1, i+1) == "" 171 assert ax.yaxis.major.formatter(0, None) == "" 172 173 174def axis_test(axis, labels): 175 ticks = list(range(len(labels))) 176 np.testing.assert_array_equal(axis.get_majorticklocs(), ticks) 177 graph_labels = [axis.major.formatter(i, i) for i in ticks] 178 assert graph_labels == [_to_str(l) for l in labels] 179 assert list(axis.units._mapping.keys()) == [l for l in labels] 180 assert list(axis.units._mapping.values()) == ticks 181 182 
class TestPlotBytes(object):
    """``str`` and ``bytes`` category data should plot to the same axis."""
    bytes_cases = [('string list', ['a', 'b', 'c']),
                   ('bytes list', [b'a', b'b', b'c']),
                   ('bytes ndarray', np.array([b'a', b'b', b'c']))]

    bytes_ids, bytes_data = zip(*bytes_cases)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    @pytest.mark.parametrize("bdata", bytes_data, ids=bytes_ids)
    def test_plot_bytes(self, ax, plotter, bdata):
        counts = np.array([4, 6, 5])
        plotter(ax, bdata, counts)
        axis_test(ax.xaxis, bdata)


class TestPlotNumlike(object):
    """Number-like strings are treated as categories, not numbers."""
    numlike_cases = [('string list', ['1', '11', '3']),
                     ('string ndarray', np.array(['1', '11', '3'])),
                     ('bytes list', [b'1', b'11', b'3']),
                     ('bytes ndarray', np.array([b'1', b'11', b'3']))]
    numlike_ids, numlike_data = zip(*numlike_cases)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    @pytest.mark.parametrize("ndata", numlike_data, ids=numlike_ids)
    def test_plot_numlike(self, ax, plotter, ndata):
        counts = np.array([4, 6, 5])
        plotter(ax, ndata, counts)
        # Input order ('1', '11', '3') is kept rather than sorted numerically.
        axis_test(ax.xaxis, ndata)


class TestPlotTypes(object):
    """Categorical x/y axes across plotting methods, updates and errors."""

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_unicode(self, ax, plotter):
        words = ['Здравствуйте', 'привет']
        plotter(ax, words, [0, 1])
        axis_test(ax.xaxis, words)

    @pytest.fixture
    def test_data(self):
        """Attach categorical (x, y) and numeric (xy, yx) data to ``self``."""
        self.x = ["hello", "happy", "world"]
        self.xy = [2, 6, 3]
        self.y = ["Python", "is", "fun"]
        self.yx = [3, 4, 5]

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_xaxis(self, ax, test_data, plotter):
        plotter(ax, self.x, self.xy)
        axis_test(ax.xaxis, self.x)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_yaxis(self, ax, test_data, plotter):
        plotter(ax, self.yx, self.y)
        axis_test(ax.yaxis, self.y)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_xyaxis(self, ax, test_data, plotter):
        plotter(ax, self.x, self.y)
        axis_test(ax.xaxis, self.x)
        axis_test(ax.yaxis, self.y)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_update_plot(self, ax, plotter):
        """Repeated plots append only new categories, in first-seen order."""
        plotter(ax, ['a', 'b'], ['e', 'g'])
        plotter(ax, ['a', 'b', 'd'], ['f', 'a', 'b'])
        plotter(ax, ['b', 'c', 'd'], ['g', 'e', 'd'])
        axis_test(ax.xaxis, ['a', 'b', 'd', 'c'])
        axis_test(ax.yaxis, ['e', 'g', 'f', 'a', 'b', 'd'])

    failing_test_cases = [("mixed", ['A', 3.14]),
                          ("number integer", ['1', 1]),
                          ("string integer", ['42', 42]),
                          ("missing", ['12', np.nan])]

    fids, fvalues = zip(*failing_test_cases)

    # plot and bar are marked xfail: with this parametrization they are
    # expected not to raise TypeError for mixed-type data.
    PLOT_BROKEN_LIST = [Axes.scatter,
                        pytest.param(Axes.plot, marks=pytest.mark.xfail),
                        pytest.param(Axes.bar, marks=pytest.mark.xfail)]

    PLOT_BROKEN_IDS = ["scatter", "plot", "bar"]

    @pytest.mark.parametrize("plotter", PLOT_BROKEN_LIST, ids=PLOT_BROKEN_IDS)
    @pytest.mark.parametrize("xdata", fvalues, ids=fids)
    def test_mixed_type_exception(self, ax, plotter, xdata):
        with pytest.raises(TypeError):
            plotter(ax, xdata, [1, 2])

    @pytest.mark.parametrize("plotter", PLOT_BROKEN_LIST, ids=PLOT_BROKEN_IDS)
    @pytest.mark.parametrize("xdata", fvalues, ids=fids)
    def test_mixed_type_update_exception(self, ax, plotter, xdata):
        # Plot the initial data outside ``raises`` so that only the second,
        # mixed-type call can satisfy the expected TypeError.
        plotter(ax, [0, 3], [1, 3])
        with pytest.raises(TypeError):
            plotter(ax, xdata, [1, 2])