1""" common utilities """ 2import itertools 3 4import numpy as np 5 6from pandas import DataFrame, Float64Index, MultiIndex, Series, UInt64Index, date_range 7import pandas._testing as tm 8 9 10def _mklbl(prefix, n): 11 return [f"{prefix}{i}" for i in range(n)] 12 13 14def _axify(obj, key, axis): 15 # create a tuple accessor 16 axes = [slice(None)] * obj.ndim 17 axes[axis] = key 18 return tuple(axes) 19 20 21class Base: 22 """ indexing comprehensive base class """ 23 24 _kinds = {"series", "frame"} 25 _typs = { 26 "ints", 27 "uints", 28 "labels", 29 "mixed", 30 "ts", 31 "floats", 32 "empty", 33 "ts_rev", 34 "multi", 35 } 36 37 def setup_method(self, method): 38 39 self.series_ints = Series(np.random.rand(4), index=np.arange(0, 8, 2)) 40 self.frame_ints = DataFrame( 41 np.random.randn(4, 4), index=np.arange(0, 8, 2), columns=np.arange(0, 12, 3) 42 ) 43 44 self.series_uints = Series( 45 np.random.rand(4), index=UInt64Index(np.arange(0, 8, 2)) 46 ) 47 self.frame_uints = DataFrame( 48 np.random.randn(4, 4), 49 index=UInt64Index(range(0, 8, 2)), 50 columns=UInt64Index(range(0, 12, 3)), 51 ) 52 53 self.series_floats = Series( 54 np.random.rand(4), index=Float64Index(range(0, 8, 2)) 55 ) 56 self.frame_floats = DataFrame( 57 np.random.randn(4, 4), 58 index=Float64Index(range(0, 8, 2)), 59 columns=Float64Index(range(0, 12, 3)), 60 ) 61 62 m_idces = [ 63 MultiIndex.from_product([[1, 2], [3, 4]]), 64 MultiIndex.from_product([[5, 6], [7, 8]]), 65 MultiIndex.from_product([[9, 10], [11, 12]]), 66 ] 67 68 self.series_multi = Series(np.random.rand(4), index=m_idces[0]) 69 self.frame_multi = DataFrame( 70 np.random.randn(4, 4), index=m_idces[0], columns=m_idces[1] 71 ) 72 73 self.series_labels = Series(np.random.randn(4), index=list("abcd")) 74 self.frame_labels = DataFrame( 75 np.random.randn(4, 4), index=list("abcd"), columns=list("ABCD") 76 ) 77 78 self.series_mixed = Series(np.random.randn(4), index=[2, 4, "null", 8]) 79 self.frame_mixed = DataFrame(np.random.randn(4, 4), index=[2, 4, "null", 8]) 80 81 self.series_ts = Series( 82 np.random.randn(4), index=date_range("20130101", periods=4) 83 ) 84 self.frame_ts = DataFrame( 85 np.random.randn(4, 4), index=date_range("20130101", periods=4) 86 ) 87 88 dates_rev = date_range("20130101", periods=4).sort_values(ascending=False) 89 self.series_ts_rev = Series(np.random.randn(4), index=dates_rev) 90 self.frame_ts_rev = DataFrame(np.random.randn(4, 4), index=dates_rev) 91 92 self.frame_empty = DataFrame() 93 self.series_empty = Series(dtype=object) 94 95 # form agglomerates 96 for kind in self._kinds: 97 d = {} 98 for typ in self._typs: 99 d[typ] = getattr(self, f"{kind}_{typ}") 100 101 setattr(self, kind, d) 102 103 def generate_indices(self, f, values=False): 104 """ 105 generate the indices 106 if values is True , use the axis values 107 is False, use the range 108 """ 109 axes = f.axes 110 if values: 111 axes = (list(range(len(ax))) for ax in axes) 112 113 return itertools.product(*axes) 114 115 def get_value(self, name, f, i, values=False): 116 """ return the value for the location i """ 117 # check against values 118 if values: 119 return f.values[i] 120 121 elif name == "iat": 122 return f.iloc[i] 123 else: 124 assert name == "at" 125 return f.loc[i] 126 127 def check_values(self, f, func, values=False): 128 129 if f is None: 130 return 131 axes = f.axes 132 indicies = itertools.product(*axes) 133 134 for i in indicies: 135 result = getattr(f, func)[i] 136 137 # check against values 138 if values: 139 expected = f.values[i] 140 else: 141 expected = f 142 for a in reversed(i): 143 expected = expected.__getitem__(a) 144 145 tm.assert_almost_equal(result, expected) 146 147 def check_result(self, method, key, typs=None, axes=None, fails=None): 148 def _eq(axis, obj, key): 149 """ compare equal for these 2 keys """ 150 axified = _axify(obj, key, axis) 151 try: 152 getattr(obj, method).__getitem__(axified) 153 154 except (IndexError, TypeError, KeyError) as detail: 155 156 # if we are in fails, the ok, otherwise raise it 157 if fails is not None: 158 if isinstance(detail, fails): 159 return 160 raise 161 162 if typs is None: 163 typs = self._typs 164 165 if axes is None: 166 axes = [0, 1] 167 else: 168 assert axes in [0, 1] 169 axes = [axes] 170 171 # check 172 for kind in self._kinds: 173 174 d = getattr(self, kind) 175 for ax in axes: 176 for typ in typs: 177 assert typ in self._typs 178 179 obj = d[typ] 180 if ax < obj.ndim: 181 _eq(axis=ax, obj=obj, key=key) 182