"""
Note: for naming purposes, most tests are titled as e.g. "test_nlargest_foo"
but are implicitly also testing nsmallest_foo.
"""
from itertools import product

import numpy as np
import pytest

import pandas as pd
from pandas import Series
import pandas._testing as tm

# Numeric dtypes exercised below. This single list feeds both ``main_dtypes``
# and the ``s_main_dtypes`` fixture so the two cannot drift apart.
numeric_dtypes = [
    "int8",
    "int16",
    "int32",
    "int64",
    "float32",
    "float64",
    "uint8",
    "uint16",
    "uint32",
    "uint64",
]

main_dtypes = ["datetime", "datetimetz", "timedelta", *numeric_dtypes]


@pytest.fixture
def s_main_dtypes():
    """
    A DataFrame with many dtypes

    * datetime
    * datetimetz
    * timedelta
    * [u]int{8,16,32,64}
    * float{32,64}

    The columns are the name of the dtype. Every column encodes the values
    [3, 2, 1, 2, 5] (as years, days or plain numbers), so position 2 is the
    unique minimum, positions 1 and 3 are tied, and position 4 is the maximum.
    """
    df = pd.DataFrame(
        {
            "datetime": pd.to_datetime(["2003", "2002", "2001", "2002", "2005"]),
            "datetimetz": pd.to_datetime(
                ["2003", "2002", "2001", "2002", "2005"]
            ).tz_localize("US/Eastern"),
            "timedelta": pd.to_timedelta(["3d", "2d", "1d", "2d", "5d"]),
        }
    )

    for dtype in numeric_dtypes:
        df[dtype] = Series([3, 2, 1, 2, 5], dtype=dtype)

    return df


@pytest.fixture(params=main_dtypes)
def s_main_dtypes_split(request, s_main_dtypes):
    """Each series in s_main_dtypes."""
    return s_main_dtypes[request.param]


def assert_check_nselect_boundary(vals, dtype, method):
    """
    Assert that ``method`` ("nsmallest" or "nlargest") picks the right 3 values.

    Helper for the 'test_nlargest_boundary_{dtype}' tests. ``vals`` must list
    two near-minimum values at positions 0, 1 and two near-maximum values at
    positions 2, 3 (any extra trailing values must not be selected).
    """
    ser = Series(vals, dtype=dtype)
    result = getattr(ser, method)(3)
    # nsmallest keeps ascending order; nlargest returns descending order.
    expected_idxr = [0, 1, 2] if method == "nsmallest" else [3, 2, 1]
    expected = ser.loc[expected_idxr]
    tm.assert_series_equal(result, expected)


class TestSeriesNLargestNSmallest:
    """Tests for Series.nlargest / Series.nsmallest."""

    @pytest.mark.parametrize(
        "r",
        [
            Series([3.0, 2, 1, 2, "5"], dtype="object"),
            Series([3.0, 2, 1, 2, 5], dtype="object"),
            # not supported on some archs
            # Series([3., 2, 1, 2, 5], dtype='complex256'),
            Series([3.0, 2, 1, 2, 5], dtype="complex128"),
            Series(list("abcde")),
            Series(list("abcde"), dtype="category"),
        ],
    )
    def test_nlargest_error(self, r):
        # Unsupported dtypes must raise TypeError for every n, including
        # n values that would otherwise return an empty or full result.
        dt = r.dtype
        msg = f"Cannot use method 'n(larg|small)est' with dtype {dt}"
        args = 2, len(r), 0, -1
        methods = r.nlargest, r.nsmallest
        for method, arg in product(methods, args):
            with pytest.raises(TypeError, match=msg):
                method(arg)

    def test_nsmallest_nlargest(self, s_main_dtypes_split):
        # float, [u]int, datetime64 (tz-naive and tz-aware) and timedelta64
        # columns of s_main_dtypes; object/complex/category dtypes are only
        # supported as errors (see test_nlargest_error).
        ser = s_main_dtypes_split

        tm.assert_series_equal(ser.nsmallest(2), ser.iloc[[2, 1]])
        tm.assert_series_equal(ser.nsmallest(2, keep="last"), ser.iloc[[2, 3]])

        # n <= 0 yields an empty result of the same dtype.
        empty = ser.iloc[0:0]
        tm.assert_series_equal(ser.nsmallest(0), empty)
        tm.assert_series_equal(ser.nsmallest(-1), empty)
        tm.assert_series_equal(ser.nlargest(0), empty)
        tm.assert_series_equal(ser.nlargest(-1), empty)

        # n >= len(ser) returns every element, fully ordered.
        tm.assert_series_equal(ser.nsmallest(len(ser)), ser.sort_values())
        tm.assert_series_equal(ser.nsmallest(len(ser) + 1), ser.sort_values())
        tm.assert_series_equal(ser.nlargest(len(ser)), ser.iloc[[4, 0, 1, 3, 2]])
        tm.assert_series_equal(ser.nlargest(len(ser) + 1), ser.iloc[[4, 0, 1, 3, 2]])

    def test_nlargest_misc(self):
        # NaN is excluded from both results.
        ser = Series([3.0, np.nan, 1, 2, 5])
        tm.assert_series_equal(ser.nlargest(), ser.iloc[[4, 0, 3, 2]])
        tm.assert_series_equal(ser.nsmallest(), ser.iloc[[2, 3, 0, 4]])

        msg = 'keep must be either "first", "last"'
        with pytest.raises(ValueError, match=msg):
            ser.nsmallest(keep="invalid")
        with pytest.raises(ValueError, match=msg):
            ser.nlargest(keep="invalid")

        # GH#15297
        # With all values tied, keep="first" takes the earliest labels and
        # keep="last" takes the latest ones (in reverse order).
        ser = Series([1] * 5, index=[1, 2, 3, 4, 5])
        expected_first = Series([1] * 3, index=[1, 2, 3])
        expected_last = Series([1] * 3, index=[5, 4, 3])

        result = ser.nsmallest(3)
        tm.assert_series_equal(result, expected_first)

        result = ser.nsmallest(3, keep="last")
        tm.assert_series_equal(result, expected_last)

        result = ser.nlargest(3)
        tm.assert_series_equal(result, expected_first)

        result = ser.nlargest(3, keep="last")
        tm.assert_series_equal(result, expected_last)

    @pytest.mark.parametrize("n", range(1, 5))
    def test_nlargest_n(self, n):
        # GH 13412
        # Duplicate index labels must not confuse the selection.
        ser = Series([1, 4, 3, 2], index=[0, 0, 1, 1])
        result = ser.nlargest(n)
        expected = ser.sort_values(ascending=False).head(n)
        tm.assert_series_equal(result, expected)

        result = ser.nsmallest(n)
        expected = ser.sort_values().head(n)
        tm.assert_series_equal(result, expected)

    def test_nlargest_boundary_integer(self, nselect_method, any_int_dtype):
        # GH#21426
        dtype_info = np.iinfo(any_int_dtype)
        min_val, max_val = dtype_info.min, dtype_info.max
        vals = [min_val, min_val + 1, max_val - 1, max_val]
        assert_check_nselect_boundary(vals, any_int_dtype, nselect_method)

    def test_nlargest_boundary_float(self, nselect_method, float_dtype):
        # GH#21426
        dtype_info = np.finfo(float_dtype)
        min_val, max_val = dtype_info.min, dtype_info.max
        # Step each extreme one ULP towards zero to get the runner-up values.
        min_2nd, max_2nd = np.nextafter([min_val, max_val], 0, dtype=float_dtype)
        vals = [min_val, min_2nd, max_2nd, max_val]
        assert_check_nselect_boundary(vals, float_dtype, nselect_method)

    @pytest.mark.parametrize("dtype", ["datetime64[ns]", "timedelta64[ns]"])
    def test_nlargest_boundary_datetimelike(self, nselect_method, dtype):
        # GH#21426
        # use int64 bounds and +1 to min_val since true minimum is NaT
        # (include min_val/NaT at end to maintain same expected_idxr)
        dtype_info = np.iinfo("int64")
        min_val, max_val = dtype_info.min, dtype_info.max
        vals = [min_val + 1, min_val + 2, max_val - 1, max_val, min_val]
        assert_check_nselect_boundary(vals, dtype, nselect_method)

    def test_nlargest_duplicate_keep_all_ties(self):
        # see GH#16818
        # keep="all" returns every value tied with the n-th, so the result
        # can be longer than n.
        ser = Series([10, 9, 8, 7, 7, 7, 7, 6])
        result = ser.nlargest(4, keep="all")
        expected = Series([10, 9, 8, 7, 7, 7, 7])
        tm.assert_series_equal(result, expected)

        result = ser.nsmallest(2, keep="all")
        expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6])
        tm.assert_series_equal(result, expected)

    @pytest.mark.parametrize(
        "data,expected", [([True, False], [True]), ([True, False, True, True], [True])]
    )
    def test_nlargest_boolean(self, data, expected):
        # GH#26154 : ensure True > False
        ser = Series(data)
        result = ser.nlargest(1)
        expected = Series(expected)
        tm.assert_series_equal(result, expected)