1"""
Note: for naming purposes, most tests are titled e.g. "test_nlargest_foo"
but implicitly also test nsmallest_foo.
4"""
5from itertools import product
6
7import numpy as np
8import pytest
9
10import pandas as pd
11from pandas import Series
12import pandas._testing as tm
13
# Dtype names the nselect tests are parametrized over (via
# ``s_main_dtypes_split``); each one is also a column of ``s_main_dtypes``.
main_dtypes = (
    ["datetime", "datetimetz", "timedelta"]
    + [f"int{bits}" for bits in (8, 16, 32, 64)]
    + ["float32", "float64"]
    + [f"uint{bits}" for bits in (8, 16, 32, 64)]
)
29
30
@pytest.fixture
def s_main_dtypes():
    """
    A DataFrame with many dtypes

    * datetime
    * datetimetz
    * timedelta
    * [u]int{8,16,32,64}
    * float{32,64}

    Each column is named after its dtype, so the columns are exactly
    ``main_dtypes`` in the same order. Every column is ordered like
    3, 2, 1, 2, 5 (the datetime-like columns use years / days in that order).
    """
    df = pd.DataFrame(
        {
            "datetime": pd.to_datetime(["2003", "2002", "2001", "2002", "2005"]),
            "datetimetz": pd.to_datetime(
                ["2003", "2002", "2001", "2002", "2005"]
            ).tz_localize("US/Eastern"),
            "timedelta": pd.to_timedelta(["3d", "2d", "1d", "2d", "5d"]),
        }
    )

    # Derive the numeric columns from ``main_dtypes`` instead of a second
    # hand-maintained list, so this fixture and the parametrization in
    # ``s_main_dtypes_split`` cannot drift apart. The datetime-like columns
    # were built above and are skipped here.
    for dtype in main_dtypes:
        if dtype not in df.columns:
            df[dtype] = Series([3, 2, 1, 2, 5], dtype=dtype)

    return df
69
70
@pytest.fixture(params=main_dtypes)
def s_main_dtypes_split(request, s_main_dtypes):
    """One column of ``s_main_dtypes``, selected by the parametrized dtype name."""
    column_name = request.param
    return s_main_dtypes[column_name]
75
76
def assert_check_nselect_boundary(vals, dtype, method):
    """
    Assert that ``getattr(Series(vals, dtype=dtype), method)(3)`` selects the
    expected labels.

    Helper for the ``test_nlargest_boundary_*`` tests. Callers lay out ``vals``
    so that the three smallest values sit at labels 0, 1, 2 and the three
    largest at labels 3, 2, 1 (largest first). Any ``method`` other than
    ``"nsmallest"`` is checked against the nlargest ordering.
    """
    series = Series(vals, dtype=dtype)
    if method == "nsmallest":
        labels = [0, 1, 2]
    else:
        labels = [3, 2, 1]
    tm.assert_series_equal(getattr(series, method)(3), series.loc[labels])
84
85
class TestSeriesNLargestNSmallest:
    """Tests for ``Series.nlargest`` and ``Series.nsmallest``."""

    @pytest.mark.parametrize(
        "r",
        [
            Series([3.0, 2, 1, 2, "5"], dtype="object"),
            Series([3.0, 2, 1, 2, 5], dtype="object"),
            # not supported on some archs
            # Series([3., 2, 1, 2, 5], dtype='complex256'),
            Series([3.0, 2, 1, 2, 5], dtype="complex128"),
            Series(list("abcde")),
            Series(list("abcde"), dtype="category"),
        ],
    )
    def test_nlargest_error(self, r):
        """Both methods raise TypeError for the unsupported dtypes above."""
        dt = r.dtype
        # ``n(larg|small)est`` is a regex alternation, so one pattern matches
        # the message raised by either method.
        msg = f"Cannot use method 'n(larg|small)est' with dtype {dt}"
        # The error is expected for every n, including values that would
        # otherwise give an empty (0, -1) or complete (len(r)) result.
        args = 2, len(r), 0, -1
        methods = r.nlargest, r.nsmallest
        for method, arg in product(methods, args):
            with pytest.raises(TypeError, match=msg):
                method(arg)

    def test_nsmallest_nlargest(self, s_main_dtypes_split):
        """Check ordering, ``keep`` and out-of-range ``n`` for each main dtype."""
        # Runs once per entry of ``main_dtypes``: datetime64 (naive and
        # tz-aware), timedelta64, [u]int{8,16,32,64} and float{32,64}.
        # Each series is ordered like [3, 2, 1, 2, 5], with a tie at
        # positions 1 and 3.
        ser = s_main_dtypes_split

        tm.assert_series_equal(ser.nsmallest(2), ser.iloc[[2, 1]])
        # keep="last" resolves the tie between positions 1 and 3 to position 3.
        tm.assert_series_equal(ser.nsmallest(2, keep="last"), ser.iloc[[2, 3]])

        # n <= 0 yields an empty series.
        empty = ser.iloc[0:0]
        tm.assert_series_equal(ser.nsmallest(0), empty)
        tm.assert_series_equal(ser.nsmallest(-1), empty)
        tm.assert_series_equal(ser.nlargest(0), empty)
        tm.assert_series_equal(ser.nlargest(-1), empty)

        # n >= len(ser) returns every element in sorted order.
        tm.assert_series_equal(ser.nsmallest(len(ser)), ser.sort_values())
        tm.assert_series_equal(ser.nsmallest(len(ser) + 1), ser.sort_values())
        tm.assert_series_equal(ser.nlargest(len(ser)), ser.iloc[[4, 0, 1, 3, 2]])
        tm.assert_series_equal(ser.nlargest(len(ser) + 1), ser.iloc[[4, 0, 1, 3, 2]])

    def test_nlargest_misc(self):
        """NaN exclusion, invalid ``keep`` values and all-equal ties."""
        # The NaN at position 1 does not appear in either result.
        ser = Series([3.0, np.nan, 1, 2, 5])
        tm.assert_series_equal(ser.nlargest(), ser.iloc[[4, 0, 3, 2]])
        tm.assert_series_equal(ser.nsmallest(), ser.iloc[[2, 3, 0, 4]])

        msg = 'keep must be either "first", "last"'
        with pytest.raises(ValueError, match=msg):
            ser.nsmallest(keep="invalid")
        with pytest.raises(ValueError, match=msg):
            ser.nlargest(keep="invalid")

        # GH#15297
        # With all values equal, "first" keeps the leading labels and "last"
        # keeps the trailing labels, listed from the end.
        ser = Series([1] * 5, index=[1, 2, 3, 4, 5])
        expected_first = Series([1] * 3, index=[1, 2, 3])
        expected_last = Series([1] * 3, index=[5, 4, 3])

        result = ser.nsmallest(3)
        tm.assert_series_equal(result, expected_first)

        result = ser.nsmallest(3, keep="last")
        tm.assert_series_equal(result, expected_last)

        result = ser.nlargest(3)
        tm.assert_series_equal(result, expected_first)

        result = ser.nlargest(3, keep="last")
        tm.assert_series_equal(result, expected_last)

    @pytest.mark.parametrize("n", range(1, 5))
    def test_nlargest_n(self, n):
        """Results match sort_values().head(n) even with a duplicated index."""
        # GH 13412
        ser = Series([1, 4, 3, 2], index=[0, 0, 1, 1])
        result = ser.nlargest(n)
        expected = ser.sort_values(ascending=False).head(n)
        tm.assert_series_equal(result, expected)

        result = ser.nsmallest(n)
        expected = ser.sort_values().head(n)
        tm.assert_series_equal(result, expected)

    def test_nlargest_boundary_integer(self, nselect_method, any_int_dtype):
        """Values at the integer dtype's min/max are selected correctly."""
        # GH#21426
        dtype_info = np.iinfo(any_int_dtype)
        min_val, max_val = dtype_info.min, dtype_info.max
        vals = [min_val, min_val + 1, max_val - 1, max_val]
        assert_check_nselect_boundary(vals, any_int_dtype, nselect_method)

    def test_nlargest_boundary_float(self, nselect_method, float_dtype):
        """Values at the float dtype's finite min/max are selected correctly."""
        # GH#21426
        dtype_info = np.finfo(float_dtype)
        min_val, max_val = dtype_info.min, dtype_info.max
        # The representable neighbours of min/max, stepping toward zero.
        min_2nd, max_2nd = np.nextafter([min_val, max_val], 0, dtype=float_dtype)
        vals = [min_val, min_2nd, max_2nd, max_val]
        assert_check_nselect_boundary(vals, float_dtype, nselect_method)

    @pytest.mark.parametrize("dtype", ["datetime64[ns]", "timedelta64[ns]"])
    def test_nlargest_boundary_datetimelike(self, nselect_method, dtype):
        """Values at the int64 bounds are selected correctly for datetimelikes."""
        # GH#21426
        # use int64 bounds and +1 to min_val since true minimum is NaT
        # (include min_val/NaT at end to maintain same expected_idxr)
        dtype_info = np.iinfo("int64")
        min_val, max_val = dtype_info.min, dtype_info.max
        vals = [min_val + 1, min_val + 2, max_val - 1, max_val, min_val]
        assert_check_nselect_boundary(vals, dtype, nselect_method)

    def test_nlargest_duplicate_keep_all_ties(self):
        """keep="all" returns every tie of the cutoff value, even beyond n."""
        # see GH#16818
        ser = Series([10, 9, 8, 7, 7, 7, 7, 6])
        result = ser.nlargest(4, keep="all")
        expected = Series([10, 9, 8, 7, 7, 7, 7])
        tm.assert_series_equal(result, expected)

        result = ser.nsmallest(2, keep="all")
        expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6])
        tm.assert_series_equal(result, expected)

    @pytest.mark.parametrize(
        "data,expected", [([True, False], [True]), ([True, False, True, True], [True])]
    )
    def test_nlargest_boolean(self, data, expected):
        """Boolean data is supported and ordered with True > False."""
        # GH#26154 : ensure True > False
        ser = Series(data)
        result = ser.nlargest(1)
        tm.assert_series_equal(result, Series(expected))
214