1"""
2This file contains a minimal set of tests for compliance with the extension
3array interface test suite, and should contain no other tests.
4The test suite for the full functionality of the array is located in
5`pandas/tests/arrays/`.
6
7The tests in this file are inherited from the BaseExtensionTests, and only
8minimal tweaks should be applied to get the tests passing (by overwriting a
9parent method).
10
11Additional tests should either be added to one of the BaseExtensionTests
12classes (if they are relevant for the extension interface for all dtypes), or
13be added to the array-specific tests in `pandas/tests/arrays/`.
14
15"""
16import string
17
18import numpy as np
19import pytest
20
21import pandas as pd
22from pandas import Categorical, CategoricalIndex, Timestamp
23import pandas._testing as tm
24from pandas.api.types import CategoricalDtype
25from pandas.tests.extension import base
26
27
def make_data():
    """Return 100 random ASCII letters whose first two entries differ.

    The letters are drawn with ``np.random.choice``; the draw is repeated
    until ``values[0] != values[1]``. Letters are never null, so the
    "first two not null" requirement holds for every draw.
    """
    letters = list(string.ascii_letters)
    values = np.random.choice(letters, size=100)
    # redraw the whole sample until the first two entries are different
    while values[0] == values[1]:
        values = np.random.choice(letters, size=100)
    return values
37
38
@pytest.fixture
def dtype():
    """Default-constructed ``CategoricalDtype`` used as the dtype under test."""
    categorical_dtype = CategoricalDtype()
    return categorical_dtype
42
43
@pytest.fixture
def data():
    """Length-100 Categorical built from the values of ``make_data()``.

    * data[0] and data[1] should both be non missing
    * data[0] and data[1] should not be equal
    """
    values = make_data()
    return Categorical(values)
52
53
@pytest.fixture
def data_missing():
    """Length-2 Categorical laid out as [NA, Valid]."""
    values = [np.nan, "A"]
    return Categorical(values)
58
59
@pytest.fixture
def data_for_sorting():
    """Ordered Categorical ["A", "B", "C"] with category order C < A < B."""
    values = ["A", "B", "C"]
    category_order = ["C", "A", "B"]
    return Categorical(values, categories=category_order, ordered=True)
63
64
@pytest.fixture
def data_missing_for_sorting():
    """Ordered Categorical ["A", None, "B"] with category order B < A."""
    values = ["A", None, "B"]
    category_order = ["B", "A"]
    return Categorical(values, categories=category_order, ordered=True)
68
69
@pytest.fixture
def na_value():
    """Missing-value scalar supplied to the inherited tests: ``np.nan``."""
    missing = np.nan
    return missing
73
74
@pytest.fixture
def data_for_grouping():
    """Length-8 Categorical with repeated "a"/"b", two missing entries and one "c"."""
    values = ["a", "a", None, None, "b", "b", "a", "c"]
    return Categorical(values)
78
79
class TestDtype(base.BaseDtypeTests):
    """Inherited dtype tests, run with no overrides."""
82
83
class TestInterface(base.BaseInterfaceTests):
    """Inherited interface tests with Categorical-specific tweaks."""

    @pytest.mark.skip(reason="Memory usage doesn't match")
    def test_memory_usage(self, data):
        """Skipped; the body only delegates to the inherited test."""
        # Is this deliberate?
        super().test_memory_usage(data)

    def test_contains(self, data, data_missing):
        """Check ``in`` for valid values, ``na_value`` and other null objects."""
        # GH-37867
        # na value handling in Categorical.__contains__ is deprecated.
        # See base.BaseInterfaceTests.test_contains for more details.

        missing = data.dtype.na_value
        # work on a copy of the data that has no missing values
        non_missing = data[~data.isna()]

        # first elements are non-missing
        assert non_missing[0] in non_missing
        assert data_missing[0] in data_missing

        # check the presence of na_value
        assert missing in data_missing
        assert missing not in non_missing

        # Categoricals can contain other nan-likes than na_value
        other_nulls = [obj for obj in tm.NULL_OBJECTS if obj is not missing]
        for null_obj in other_nulls:
            assert null_obj not in non_missing
            # this line differs from super method
            assert null_obj in data_missing
113
114
class TestConstructors(base.BaseConstructorsTests):
    """Inherited constructor tests, run with no overrides."""
117
118
class TestReshaping(base.BaseReshapingTests):
    """Inherited reshaping tests; ``test_concat_with_reindex`` is xfailed."""

    def test_concat_with_reindex(self, data):
        """Mark the test as an expected failure as soon as it runs."""
        pytest.xfail(reason="Deliberately upcast to object?")
122
123
class TestGetitem(base.BaseGetitemTests):
    """Inherited getitem tests; the scalar-getitem test is skipped."""

    @pytest.mark.skip(reason="Backwards compatibility")
    def test_getitem_scalar(self, data):
        """Skipped; the body only delegates to the inherited test."""
        # CategoricalDtype.type isn't "correct" since it should
        # be a parent of the elements (object). But don't want
        # to break things by changing.
        super().test_getitem_scalar(data)
131
132
class TestSetitem(base.BaseSetitemTests):
    """Inherited setitem tests, run with no overrides."""
135
136
class TestMissing(base.BaseMissingTests):
    """Inherited missing-data tests; both ``fillna`` limit tests are skipped."""

    @pytest.mark.skip(reason="Not implemented")
    def test_fillna_limit_pad(self, data_missing):
        """Skipped; the body only delegates to the inherited test."""
        super().test_fillna_limit_pad(data_missing)

    @pytest.mark.skip(reason="Not implemented")
    def test_fillna_limit_backfill(self, data_missing):
        """Skipped; the body only delegates to the inherited test."""
        super().test_fillna_limit_backfill(data_missing)
145
146
class TestReduce(base.BaseNoReduceTests):
    """Inherited ``BaseNoReduceTests`` suite, run with no overrides."""
149
150
class TestMethods(base.BaseMethodsTests):
    """Inherited method tests with Categorical-specific overrides and skips."""

    @pytest.mark.skip(reason="Unobserved categories included")
    def test_value_counts(self, all_data, dropna):
        """Skipped; the body only delegates to the inherited test."""
        return super().test_value_counts(all_data, dropna)

    def test_combine_add(self, data_repeated):
        """Check ``Series.combine`` with ``+`` against element-wise addition."""
        # GH 20825
        # When adding categoricals in combine, result is a string
        orig_data1, orig_data2 = data_repeated(2)
        s1 = pd.Series(orig_data1)
        s2 = pd.Series(orig_data2)

        # Series combined with another Series
        result = s1.combine(s2, lambda x1, x2: x1 + x2)
        expected = pd.Series([a + b for a, b in zip(orig_data1, orig_data2)])
        self.assert_series_equal(result, expected)

        # Series combined with a scalar taken from its own first element
        val = s1.iloc[0]
        result = s1.combine(val, lambda x1, x2: x1 + x2)
        expected = pd.Series([a + val for a in orig_data1])
        self.assert_series_equal(result, expected)

    @pytest.mark.skip(reason="Not Applicable")
    def test_fillna_length_mismatch(self, data_missing):
        """Skipped; the body only delegates to the inherited test."""
        super().test_fillna_length_mismatch(data_missing)

    def test_searchsorted(self, data_for_sorting):
        """Skip when the sorting fixture is not an ordered Categorical."""
        # NOTE(review): when the data is ordered this override makes no
        # assertions and does not delegate to the inherited test -- confirm
        # whether that is intended.
        if not data_for_sorting.ordered:
            # pytest.skip raises on its own; the message is passed
            # positionally, as in the other pytest.skip calls in this file.
            pytest.skip("searchsorted requires ordered data.")
180
181
class TestCasting(base.BaseCastingTests):
    """Inherited casting tests plus Categorical-specific ``astype`` checks."""

    @pytest.mark.parametrize("cls", [Categorical, CategoricalIndex])
    @pytest.mark.parametrize("values", [[1, np.nan], [Timestamp("2000"), pd.NaT]])
    def test_cast_nan_to_int(self, cls, values):
        """Casting data that contains a missing value to ``int`` must raise."""
        # GH 28406
        s = cls(values)

        msg = "Cannot (cast|convert)"
        with pytest.raises((ValueError, TypeError), match=msg):
            s.astype(int)

    @pytest.mark.parametrize(
        "expected",
        [
            pd.Series(["2019", "2020"], dtype="datetime64[ns, UTC]"),
            pd.Series([0, 0], dtype="timedelta64[ns]"),
            pd.Series([pd.Period("2019"), pd.Period("2020")], dtype="period[A-DEC]"),
            pd.Series([pd.Interval(0, 1), pd.Interval(1, 2)], dtype="interval"),
            pd.Series([1, np.nan], dtype="Int64"),
        ],
    )
    def test_cast_category_to_extension_dtype(self, expected):
        """Round-trip a Series through ``category`` back to its own dtype."""
        # GH 28668
        result = expected.astype("category").astype(expected.dtype)

        tm.assert_series_equal(result, expected)

    @pytest.mark.parametrize(
        "dtype, expected",
        [
            (
                "datetime64[ns]",
                np.array(["2015-01-01T00:00:00.000000000"], dtype="datetime64[ns]"),
            ),
            (
                "datetime64[ns, MET]",
                pd.DatetimeIndex(
                    [Timestamp("2015-01-01 00:00:00+0100", tz="MET")]
                ).array,
            ),
        ],
    )
    def test_consistent_casting(self, dtype, expected):
        """Cast a one-element string Categorical to a datetime dtype."""
        # GH 28448
        # Pass a list rather than a bare scalar: scalar input to the
        # Categorical constructor is deprecated (GH 38433).
        result = Categorical(["2015-01-01"]).astype(dtype)
        assert result == expected
228
229
class TestArithmeticOps(base.BaseArithmeticOpsTests):
    """Inherited arithmetic tests, skipping ``__rmod__`` and expecting TypeError."""

    def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
        """Run the inherited frame-with-scalar test for every op but ``__rmod__``."""
        # frame & scalar
        if all_arithmetic_operators == "__rmod__":
            pytest.skip("rmod never called when string is first argument")
        super().test_arith_frame_with_scalar(data, all_arithmetic_operators)

    def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
        """Run the inherited series-with-scalar test for every op but ``__rmod__``."""
        if all_arithmetic_operators == "__rmod__":
            pytest.skip("rmod never called when string is first argument")
        super().test_arith_series_with_scalar(data, all_arithmetic_operators)

    def test_add_series_with_extension_array(self, data):
        """Adding a Series of the data to the data itself must raise TypeError."""
        series = pd.Series(data)
        expected_msg = "cannot perform|unsupported operand"
        with pytest.raises(TypeError, match=expected_msg):
            series + data

    def test_divmod_series_array(self):
        """Intentionally empty override; see the comment below."""
        # GH 23287
        # skipping because it is not implemented

    def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
        """Delegate to the inherited check, always passing ``exc=TypeError``."""
        # the ``exc`` argument is ignored; TypeError is always expected
        return super()._check_divmod_op(s, op, other, exc=TypeError)
259
260
class TestComparisonOps(base.BaseComparisonOpsTests):
    """Inherited comparison tests adapted to unordered Categoricals."""

    def _compare_other(self, s, data, op_name, other):
        """Check ``==``/``!=`` element-wise; every other op must raise TypeError."""
        scalar_ops = {
            "__eq__": lambda x, y: x == y,
            "__ne__": lambda x, y: x != y,
        }
        op = self.get_op_from_name(op_name)
        scalar_op = scalar_ops.get(op_name)

        if scalar_op is None:
            # anything other than equality / inequality is rejected
            msg = "Unordered Categoricals can only compare equality or not"
            with pytest.raises(TypeError, match=msg):
                op(data, other)
            return

        result = op(s, other)
        expected = s.combine(other, scalar_op)
        assert (result == expected).all()

    @pytest.mark.parametrize(
        "categories",
        [["a", "b"], [0, 1], [Timestamp("2019"), Timestamp("2020")]],
    )
    def test_not_equal_with_na(self, categories):
        """``!=`` must be True everywhere, including where one side is missing."""
        # https://github.com/pandas-dev/pandas/issues/32276
        with_na = Categorical.from_codes([-1, 0], categories=categories)
        without_na = Categorical.from_codes([0, 1], categories=categories)

        not_equal = with_na != without_na

        assert not_equal.all()
291
292
class TestParsing(base.BaseParsingTests):
    """Inherited parsing tests, run with no overrides."""
295