1""" common utilities """
2import itertools
3
4import numpy as np
5
6from pandas import DataFrame, Float64Index, MultiIndex, Series, UInt64Index, date_range
7import pandas._testing as tm
8
9
10def _mklbl(prefix, n):
11    return [f"{prefix}{i}" for i in range(n)]
12
13
14def _axify(obj, key, axis):
15    # create a tuple accessor
16    axes = [slice(None)] * obj.ndim
17    axes[axis] = key
18    return tuple(axes)
19
20
class Base:
    """Comprehensive base class for indexing tests.

    ``setup_method`` builds one Series and one DataFrame for each index
    flavour listed in ``_typs`` and groups them into the ``self.series`` and
    ``self.frame`` dicts (keyed by flavour name), which ``check_result``
    iterates over.
    """

    # Object kinds; each name is also the attribute that ``setup_method``
    # fills with a {typ: object} dict (``self.series`` / ``self.frame``).
    _kinds = {"series", "frame"}
    # Index flavours; ``setup_method`` defines ``<kind>_<typ>`` for each one.
    _typs = {
        "ints",
        "uints",
        "labels",
        "mixed",
        "ts",
        "floats",
        "empty",
        "ts_rev",
        "multi",
    }

    def setup_method(self, method):
        """Create the fixture objects for every (kind, typ) combination.

        ``method`` is part of the pytest xunit-style hook signature and is
        not used here.
        """
        # Integer labels 0, 2, 4, 6 (columns 0, 3, 6, 9).
        self.series_ints = Series(np.random.rand(4), index=np.arange(0, 8, 2))
        self.frame_ints = DataFrame(
            np.random.randn(4, 4), index=np.arange(0, 8, 2), columns=np.arange(0, 12, 3)
        )

        # Same label values, held in unsigned-integer indexes.
        self.series_uints = Series(
            np.random.rand(4), index=UInt64Index(np.arange(0, 8, 2))
        )
        self.frame_uints = DataFrame(
            np.random.randn(4, 4),
            index=UInt64Index(range(0, 8, 2)),
            columns=UInt64Index(range(0, 12, 3)),
        )

        # Same label values, held in float indexes.
        self.series_floats = Series(
            np.random.rand(4), index=Float64Index(range(0, 8, 2))
        )
        self.frame_floats = DataFrame(
            np.random.randn(4, 4),
            index=Float64Index(range(0, 8, 2)),
            columns=Float64Index(range(0, 12, 3)),
        )

        # Two-level MultiIndexes with four entries each; only the first two
        # are used below.
        m_idces = [
            MultiIndex.from_product([[1, 2], [3, 4]]),
            MultiIndex.from_product([[5, 6], [7, 8]]),
            MultiIndex.from_product([[9, 10], [11, 12]]),
        ]

        self.series_multi = Series(np.random.rand(4), index=m_idces[0])
        self.frame_multi = DataFrame(
            np.random.randn(4, 4), index=m_idces[0], columns=m_idces[1]
        )

        # String labels.
        self.series_labels = Series(np.random.randn(4), index=list("abcd"))
        self.frame_labels = DataFrame(
            np.random.randn(4, 4), index=list("abcd"), columns=list("ABCD")
        )

        # Mixed int/str labels (the frame keeps default columns).
        self.series_mixed = Series(np.random.randn(4), index=[2, 4, "null", 8])
        self.frame_mixed = DataFrame(np.random.randn(4, 4), index=[2, 4, "null", 8])

        # Ascending daily timestamps starting 2013-01-01.
        self.series_ts = Series(
            np.random.randn(4), index=date_range("20130101", periods=4)
        )
        self.frame_ts = DataFrame(
            np.random.randn(4, 4), index=date_range("20130101", periods=4)
        )

        # The same timestamps in descending order.
        dates_rev = date_range("20130101", periods=4).sort_values(ascending=False)
        self.series_ts_rev = Series(np.random.randn(4), index=dates_rev)
        self.frame_ts_rev = DataFrame(np.random.randn(4, 4), index=dates_rev)

        self.frame_empty = DataFrame()
        self.series_empty = Series(dtype=object)

        # Form agglomerates: self.series / self.frame map typ -> object.
        for kind in self._kinds:
            setattr(
                self,
                kind,
                {typ: getattr(self, f"{kind}_{typ}") for typ in self._typs},
            )

    def generate_indices(self, f, values=False):
        """Return an iterator over every index tuple of ``f``.

        Parameters
        ----------
        f : Series or DataFrame
            Object whose ``axes`` are enumerated.
        values : bool, default False
            If True, the tuples are positional: each axis is replaced by
            ``range(len(axis))``. If False, the tuples are built from the
            axis labels themselves.

        Returns
        -------
        iterator of tuple
            Cartesian product over all axes, with one element per axis.
        """
        axes = f.axes
        if values:
            axes = (list(range(len(ax))) for ax in axes)

        return itertools.product(*axes)

    def get_value(self, name, f, i, values=False):
        """Return the reference value of ``f`` at location ``i``.

        Parameters
        ----------
        name : {"iat", "at"}
            Accessor under test. "iat" is checked against ``f.iloc[i]`` and
            "at" against ``f.loc[i]``.
        f : Series or DataFrame
            Object to read from.
        i : object
            Location key: positional for "iat", label-based for "at".
        values : bool, default False
            If True, ignore ``name`` and read ``f.values[i]`` instead.

        Raises
        ------
        AssertionError
            If ``values`` is False and ``name`` is neither "iat" nor "at".
        """
        # check against values
        if values:
            return f.values[i]

        elif name == "iat":
            return f.iloc[i]
        else:
            assert name == "at"
            return f.loc[i]

    def check_values(self, f, func, values=False):
        """Check ``getattr(f, func)[i]`` against a reference for every label tuple ``i``.

        Parameters
        ----------
        f : Series or DataFrame or None
            Object to check. ``None`` is a no-op.
        func : str
            Name of the indexer attribute to exercise (e.g. "loc").
        values : bool, default False
            If True, the reference is ``f.values[i]``. Otherwise it is obtained
            by chained ``__getitem__`` calls that apply the labels in reverse
            axis order (column before row for a DataFrame).
        """
        if f is None:
            return

        for i in itertools.product(*f.axes):
            result = getattr(f, func)[i]

            # check against values
            if values:
                # NOTE: ``i`` holds axis labels, so this reference only makes
                # sense when the labels are usable as array positions.
                expected = f.values[i]
            else:
                expected = f
                for a in reversed(i):
                    expected = expected.__getitem__(a)

            tm.assert_almost_equal(result, expected)

    def check_result(self, method, key, typs=None, axes=None, fails=None):
        """Access ``key`` through ``getattr(obj, method)`` on every fixture object.

        For each kind in ``_kinds``, each requested axis and each requested
        typ, ``key`` is placed on that axis (all other axes get
        ``slice(None)``) and the indexer is invoked. No result is compared.
        The check is that the access does not raise, or that it raises an
        exception allowed by ``fails``.

        Parameters
        ----------
        method : str
            Name of the indexer attribute, e.g. "loc" or "iloc".
        key : object
            Key placed on the axis under test.
        typs : iterable of str, optional
            Subset of ``_typs`` to check. Defaults to all of them.
        axes : {0, 1}, optional
            Single axis to check. Defaults to both 0 and 1. Axes at or beyond
            an object's ``ndim`` are skipped.
        fails : exception type or tuple of types, optional
            IndexError/TypeError/KeyError (sub)types that are tolerated.

        Raises
        ------
        IndexError, TypeError, KeyError
            If the access raises one of these and it is not matched by ``fails``.
        AssertionError
            If ``axes`` is not 0 or 1, or a name in ``typs`` is not in ``_typs``.
        """

        def _eq(axis, obj, key):
            """Access ``key`` on ``axis`` of ``obj``; re-raise unless allowed by ``fails``."""
            axified = _axify(obj, key, axis)
            try:
                getattr(obj, method).__getitem__(axified)

            except (IndexError, TypeError, KeyError) as detail:
                # An exception listed in ``fails`` is an expected outcome;
                # anything else is a genuine test failure.
                if fails is not None and isinstance(detail, fails):
                    return
                raise

        if typs is None:
            typs = self._typs

        if axes is None:
            axes = [0, 1]
        else:
            assert axes in [0, 1]
            axes = [axes]

        # check
        for kind in self._kinds:

            d = getattr(self, kind)
            for ax in axes:
                for typ in typs:
                    assert typ in self._typs

                    obj = d[typ]
                    # Series have no axis 1; skip axes the object lacks.
                    if ax < obj.ndim:
                        _eq(axis=ax, obj=obj, key=key)