1"""
2This file contains a minimal set of tests for compliance with the extension
3array interface test suite, and should contain no other tests.
4The test suite for the full functionality of the array is located in
5`pandas/tests/arrays/`.
6
The tests in this file are inherited from the BaseExtensionTests, and only
minimal tweaks should be applied to get the tests passing (by overriding a
parent method).
10
11Additional tests should either be added to one of the BaseExtensionTests
12classes (if they are relevant for the extension interface for all dtypes), or
13be added to the array-specific tests in `pandas/tests/arrays/`.
14
15"""
16import numpy as np
17import pytest
18
19import pandas as pd
20import pandas._testing as tm
21from pandas.core.arrays.boolean import BooleanDtype
22from pandas.tests.extension import base
23
24
def make_data():
    """
    Return the 100 raw values used to build the ``data`` fixture.

    The values alternate True/False, with ``np.nan`` placed at positions 8
    and 97 so the resulting array contains two missing entries.
    """
    leading = [True, False] * 4
    middle = [True, False] * 44
    trailing = [True, False]
    return leading + [np.nan] + middle + [np.nan] + trailing
27
28
@pytest.fixture
def dtype():
    """Provide the nullable boolean extension dtype under test."""
    boolean_dtype = BooleanDtype()
    return boolean_dtype
32
33
@pytest.fixture
def data(dtype):
    """Provide a length-100 array of ``dtype`` built from ``make_data()``."""
    raw_values = make_data()
    return pd.array(raw_values, dtype=dtype)
37
38
@pytest.fixture
def data_for_twos(dtype):
    """
    Provide a length-100 array of ``dtype`` built from ``np.ones(100)``.

    A boolean array cannot hold the value 2, so an all-ones input is used
    instead.
    """
    all_ones = np.ones(100)
    return pd.array(all_ones, dtype=dtype)
42
43
@pytest.fixture
def data_missing(dtype):
    """Provide a length-2 array of ``dtype``: a missing value, then True."""
    values = [np.nan, True]
    return pd.array(values, dtype=dtype)
47
48
@pytest.fixture
def data_for_sorting(dtype):
    """
    Provide the [B, C, A] sorting layout as [True, True, False].

    Booleans have only two distinct values, so B and C are both True.
    """
    values = [True, True, False]
    return pd.array(values, dtype=dtype)
52
53
@pytest.fixture
def data_missing_for_sorting(dtype):
    """Provide the [B, NA, A] sorting layout as [True, missing, False]."""
    values = [True, np.nan, False]
    return pd.array(values, dtype=dtype)
57
58
@pytest.fixture
def na_cmp():
    """Provide a comparator that is true only when both arguments are ``pd.NA``."""

    def _both_na(x, y):
        # Compare by identity against the pd.NA sentinel.
        return x is pd.NA and y is pd.NA

    return _both_na
63
64
@pytest.fixture
def na_value():
    """Provide the missing-value scalar used by the boolean dtype (``pd.NA``)."""
    missing = pd.NA
    return missing
68
69
@pytest.fixture
def data_for_grouping(dtype):
    """
    Provide [True, True, NA, NA, False, False, True] for the groupby tests.

    This follows the [B, B, NA, NA, A, A, B] layout with B=True and A=False;
    there is no third value available for booleans.
    """
    high, low, missing = True, False, np.nan
    values = [high, high, missing, missing, low, low, high]
    return pd.array(values, dtype=dtype)
76
77
class TestDtype(base.BaseDtypeTests):
    """Run the shared dtype tests unchanged for BooleanDtype."""
80
81
class TestInterface(base.BaseInterfaceTests):
    """Run the shared interface tests unchanged for BooleanArray."""
84
85
class TestConstructors(base.BaseConstructorsTests):
    """Run the shared constructor tests unchanged for BooleanArray."""
88
89
class TestGetitem(base.BaseGetitemTests):
    """Run the shared indexing (getitem) tests unchanged for BooleanArray."""
92
93
class TestSetitem(base.BaseSetitemTests):
    """Run the shared assignment (setitem) tests unchanged for BooleanArray."""
96
97
class TestMissing(base.BaseMissingTests):
    """Run the shared missing-value tests unchanged for BooleanArray."""
100
101
class TestArithmeticOps(base.BaseArithmeticOpsTests):
    """
    Arithmetic extension tests adapted to BooleanArray.

    The overrides force ``exc=None`` into the base checks, so ops are
    expected to produce a result instead of raising. The expected value is
    computed with ``combine`` and then cast to the dtype BooleanArray returns.
    """

    # Despite the name, these are the ops expected to raise a TypeError
    # matching "numpy boolean subtract" (see ``_check_op``).
    implements = {"__sub__", "__rsub__"}

    def check_opname(self, s, op_name, other, exc=None):
        # overriding to indicate ops don't raise an error
        super().check_opname(s, op_name, other, exc=None)

    def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
        """
        Check ``op(s, other)``.

        With ``exc=None`` (what ``check_opname`` passes), ops in
        ``implements`` must raise TypeError. All other ops are compared
        against ``s.combine(other, op)``, cast to the dtype BooleanArray
        produces. With a non-None ``exc``, ``op(s, other)`` must raise
        ``exc``.
        """
        if exc is None:
            if op_name in self.implements:
                msg = r"numpy boolean subtract"
                with pytest.raises(TypeError, match=msg):
                    op(s, other)
                return

            result = op(s, other)
            expected = s.combine(other, op)

            if op_name in (
                "__floordiv__",
                "__rfloordiv__",
                "__pow__",
                "__rpow__",
                "__mod__",
                "__rmod__",
            ):
                # combine keeps boolean type
                expected = expected.astype("Int8")
            elif op_name in ("__truediv__", "__rtruediv__"):
                # combine with bools does not generate the correct result
                #  (numpy behaviour for div is to regard the bools as numeric)
                expected = s.astype(float).combine(other, op).astype("Float64")
            if op_name == "__rpow__":
                # for rpow, combine does not propagate NaN
                expected[result.isna()] = np.nan
            self.assert_series_equal(result, expected)
        else:
            with pytest.raises(exc):
                op(s, other)

    def _check_divmod_op(self, s, op, other, exc=None):
        # override to not raise an error
        super()._check_divmod_op(s, op, other, None)

    @pytest.mark.skip(reason="BooleanArray does not error on ops")
    def test_error(self, data, all_arithmetic_operators):
        # other specific errors tested in the boolean array specific tests
        pass

    def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
        # frame & scalar
        op_name = all_arithmetic_operators
        if op_name not in self.implements:
            # The base test presumably passes a DataFrame here (per its name).
            # ``_check_op`` builds ``expected`` with ``s.combine(other, op)``,
            # which does not accept a DataFrame operand combined with a
            # scalar. Only the ops in ``implements`` return before reaching
            # that call. The previous reason ("_reduce needs implementation")
            # did not describe this failure.
            mark = pytest.mark.xfail(
                reason="combine-based expected result does not support "
                "DataFrame & scalar"
            )
            request.node.add_marker(mark)
        super().test_arith_frame_with_scalar(data, all_arithmetic_operators)
158
159
class TestComparisonOps(base.BaseComparisonOpsTests):
    """
    Comparison extension tests adapted to BooleanArray.

    The scalar and array comparison tests are skipped here. Per the skip
    reasons, they are covered in tests/arrays/test_boolean.py.
    """

    def check_opname(self, s, op_name, other, exc=None):
        # Comparisons do not raise for BooleanArray, so any ``exc`` requested
        # by the base class is replaced with None.
        super().check_opname(s, op_name, other, exc=None)

    def _compare_other(self, s, data, op_name, other):
        # ``data`` is not needed; delegate straight to the op check.
        self.check_opname(s, op_name, other)

    @pytest.mark.skip(reason="Tested in tests/arrays/test_boolean.py")
    def test_compare_scalar(self, data, all_compare_operators):
        """Skipped; see the skip reason."""

    @pytest.mark.skip(reason="Tested in tests/arrays/test_boolean.py")
    def test_compare_array(self, data, all_compare_operators):
        """Skipped; see the skip reason."""
175
176
class TestReshaping(base.BaseReshapingTests):
    """Run the shared reshaping tests unchanged for BooleanArray."""
179
180
class TestMethods(base.BaseMethodsTests):
    """Extension-array method tests adapted to a dtype with two values."""

    @pytest.mark.parametrize("na_sentinel", [-1, -2])
    def test_factorize(self, data_for_grouping, na_sentinel):
        # Overridden: booleans have only two uniques, True (first seen at
        # position 0) and False (first seen at position 4).
        codes, uniques = pd.factorize(data_for_grouping, na_sentinel=na_sentinel)
        expected_codes = np.array(
            [0, 0, na_sentinel, na_sentinel, 1, 1, 0], dtype=np.intp
        )
        expected_uniques = data_for_grouping.take([0, 4])

        tm.assert_numpy_array_equal(codes, expected_codes)
        self.assert_extension_array_equal(uniques, expected_uniques)

    def test_combine_le(self, data_repeated):
        # Overridden: the expected result uses the nullable "boolean" dtype
        # rather than numpy bool.
        def less_equal(left, right):
            return left <= right

        first, second = data_repeated(2)
        ser_first = pd.Series(first)
        ser_second = pd.Series(second)

        result = ser_first.combine(ser_second, less_equal)
        expected = pd.Series(
            [less_equal(a, b) for a, b in zip(list(first), list(second))],
            dtype="boolean",
        )
        self.assert_series_equal(result, expected)

        scalar = ser_first.iloc[0]
        result = ser_first.combine(scalar, less_equal)
        expected = pd.Series(
            [less_equal(a, scalar) for a in list(first)], dtype="boolean"
        )
        self.assert_series_equal(result, expected)

    def test_searchsorted(self, data_for_sorting, as_series):
        # Overridden: the requested fixture has a repeated value, so a
        # two-value array is built locally and used instead.
        two_values = pd.array([True, False], dtype="boolean")
        b, a = two_values
        arr = type(two_values)._from_sequence([a, b])

        if as_series:
            arr = pd.Series(arr)

        assert arr.searchsorted(a) == 0
        assert arr.searchsorted(a, side="right") == 1
        assert arr.searchsorted(b) == 1
        assert arr.searchsorted(b, side="right") == 2

        positions = arr.searchsorted(arr.take([0, 1]))
        tm.assert_numpy_array_equal(positions, np.array([0, 1], dtype=np.intp))

        # With an explicit sorter
        order = np.array([1, 0])
        assert two_values.searchsorted(a, sorter=order) == 0

    @pytest.mark.skip(reason="uses nullable integer")
    def test_value_counts(self, all_data, dropna):
        return super().test_value_counts(all_data, dropna)

    @pytest.mark.skip(reason="uses nullable integer")
    def test_value_counts_with_normalize(self, data):
        pass

    def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting):
        # Overridden: only two distinct values exist.
        # data_for_sorting is [B, C, A] with A < B < C, here [True, True, False].
        assert data_for_sorting.argmax() == 0
        assert data_for_sorting.argmin() == 2

        # Repeated values: the first occurrence is reported.
        repeated = data_for_sorting.take([2, 0, 0, 1, 1, 2])
        assert repeated.argmax() == 1
        assert repeated.argmin() == 0

        # Missing values: data_missing_for_sorting is [B, NA, A] with A < B.
        assert data_missing_for_sorting.argmax() == 0
        assert data_missing_for_sorting.argmin() == 2
258
259
class TestCasting(base.BaseCastingTests):
    """Run the shared casting tests unchanged for BooleanArray."""
262
263
class TestGroupby(base.BaseGroupbyTests):
    """
    Groupby tests overridden for BooleanDtype.

    Boolean data has only two unique values, whereas the base tests use
    three groups.
    """

    def test_grouping_grouper(self, data_for_grouping):
        frame = pd.DataFrame(
            {"A": ["B", "B", None, None, "A", "A", "B"], "B": data_for_grouping}
        )
        grouping_a = frame.groupby("A").grouper.groupings[0]
        grouping_b = frame.groupby("B").grouper.groupings[0]

        tm.assert_numpy_array_equal(grouping_a.grouper, frame.A.values)
        tm.assert_extension_array_equal(grouping_b.grouper, data_for_grouping)

    @pytest.mark.parametrize("as_index", [True, False])
    def test_groupby_extension_agg(self, as_index, data_for_grouping):
        frame = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
        result = frame.groupby("B", as_index=as_index).A.mean()

        # Sorted keys are [False, True]; their "A" means are 3 and 1.
        _, uniques = pd.factorize(data_for_grouping, sort=True)
        expected = pd.Series([3, 1], index=pd.Index(uniques, name="B"), name="A")

        if as_index:
            self.assert_series_equal(result, expected)
        else:
            self.assert_frame_equal(result, expected.reset_index())

    def test_groupby_agg_extension(self, data_for_grouping):
        # GH#38980 groupby agg on extension type fails for non-numeric types
        frame = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})

        # The first row of each "A" group is at positions 0, 2 and 4.
        expected = frame.iloc[[0, 2, 4]].set_index("A")

        for aggregate in (
            lambda grouped: grouped.agg({"B": "first"}),
            lambda grouped: grouped.agg("first"),
            lambda grouped: grouped.first(),
        ):
            result = aggregate(frame.groupby("A"))
            self.assert_frame_equal(result, expected)

    def test_groupby_extension_no_sort(self, data_for_grouping):
        frame = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
        result = frame.groupby("B", sort=False).A.mean()

        # Unsorted keys keep first-appearance order: True, then False.
        _, uniques = pd.factorize(data_for_grouping, sort=False)
        expected = pd.Series([1, 3], index=pd.Index(uniques, name="B"), name="A")
        self.assert_series_equal(result, expected)

    def test_groupby_extension_transform(self, data_for_grouping):
        not_missing = data_for_grouping[~data_for_grouping.isna()]
        frame = pd.DataFrame({"A": [1, 1, 3, 3, 1], "B": not_missing})

        result = frame.groupby("B").A.transform(len)
        # Each row receives the size of its "B" group.
        expected = pd.Series([3, 3, 2, 2, 3], name="A")
        self.assert_series_equal(result, expected)

    def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op):
        frame = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
        # Smoke test: each apply variant only has to run without raising.
        frame.groupby("B").apply(groupby_apply_op)
        frame.groupby("B").A.apply(groupby_apply_op)
        frame.groupby("A").apply(groupby_apply_op)
        frame.groupby("A").B.apply(groupby_apply_op)

    def test_groupby_apply_identity(self, data_for_grouping):
        frame = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
        result = frame.groupby("A").B.apply(lambda x: x.array)

        group_rows = ([0, 1, 6], [2, 3], [4, 5])
        expected = pd.Series(
            [frame.B.iloc[rows].array for rows in group_rows],
            index=pd.Index([1, 2, 3], name="A"),
            name="B",
        )
        self.assert_series_equal(result, expected)

    def test_in_numeric_groupby(self, data_for_grouping):
        frame = pd.DataFrame(
            {
                "A": [1, 1, 2, 2, 3, 3, 1],
                "B": data_for_grouping,
                "C": [1] * 7,
            }
        )
        result = frame.groupby("A").sum().columns

        is_numeric = data_for_grouping.dtype._is_numeric
        expected = pd.Index(["B", "C"] if is_numeric else ["C"])
        tm.assert_index_equal(result, expected)

    @pytest.mark.parametrize("min_count", [0, 10])
    def test_groupby_sum_mincount(self, data_for_grouping, min_count):
        frame = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
        result = frame.groupby("A").sum(min_count=min_count)

        # min_count=0: every group sums (the all-NA group gives 0).
        # min_count=10: no group has enough values, so every sum is NA.
        sums = [3, 0, 0] if min_count == 0 else [pd.NA] * 3
        expected = pd.DataFrame(
            {"B": pd.array(sums, dtype="Int64")},
            index=pd.Index([1, 2, 3], name="A"),
        )
        tm.assert_frame_equal(result, expected)
382
383
class TestNumericReduce(base.BaseNumericReduceTests):
    """Numeric reductions checked against the same reduction on float64."""

    def check_reduce(self, s, op_name, skipna):
        # The reference value comes from the float64 cast of ``s``.
        # A NaN reference maps to pd.NA, and min/max are compared as bool.
        result = getattr(s, op_name)(skipna=skipna)
        reference = getattr(s.astype("float64"), op_name)(skipna=skipna)

        if np.isnan(reference):
            expected = pd.NA
        elif op_name in ("min", "max"):
            expected = bool(reference)
        else:
            expected = reference
        tm.assert_almost_equal(result, expected)
394
395
class TestBooleanReduce(base.BaseBooleanReduceTests):
    """Run the shared boolean-reduction tests unchanged for BooleanArray."""
398
399
class TestPrinting(base.BasePrintingTests):
    """Run the shared repr/printing tests unchanged for BooleanArray."""
402
403
class TestUnaryOps(base.BaseUnaryOpsTests):
    """Run the shared unary-op tests unchanged for BooleanArray."""
406
407
class TestParsing(base.BaseParsingTests):
    """Run the shared parsing tests unchanged for BooleanArray."""
410