1import numpy as np
2import pytest
3
4import pandas as pd
5from pandas import CategoricalIndex, Index, IntervalIndex, Timestamp
6import pandas._testing as tm
7
8
class TestTake:
    """Tests for ``CategoricalIndex.take`` and its fill-value handling."""

    @staticmethod
    def _assert_taken(result, expected, *, check_values=True):
        """Assert ``result`` equals ``expected`` as an index.

        When ``check_values`` is true, the ``.values`` of both sides are also
        compared with ``tm.assert_categorical_equal``.
        """
        tm.assert_index_equal(result, expected)
        if check_values:
            tm.assert_categorical_equal(result.values, expected.values)

    @staticmethod
    def _assert_invalid_positions_raise(idx):
        """Check the errors ``idx.take`` raises for out-of-range positions.

        With a fill value, any position below -1 is a ValueError. Without
        one, position -5 is an IndexError (the message expects size 3).
        """
        fill_msg = (
            "When allow_fill=True and fill_value is not None, "
            "all indices must be >= -1"
        )
        for too_negative in (-2, -5):
            with pytest.raises(ValueError, match=fill_msg):
                idx.take(np.array([1, 0, too_negative]), fill_value=True)

        bounds_msg = "index -5 is out of bounds for (axis 0 with )?size 3"
        with pytest.raises(IndexError, match=bounds_msg):
            idx.take(np.array([1, -5]))

    def test_take_fill_value(self):
        # GH 12631
        positions = np.array([1, 0, -1])

        # numeric categories: -1 wraps around unless a fill value is given
        numeric = CategoricalIndex([1, 2, 3], name="xxx")
        self._assert_taken(
            numeric.take(positions), CategoricalIndex([2, 1, 3], name="xxx")
        )
        self._assert_taken(
            numeric.take(positions, fill_value=True),
            CategoricalIndex([2, 1, np.nan], categories=[1, 2, 3], name="xxx"),
        )
        # allow_fill=False ignores fill_value and wraps again
        self._assert_taken(
            numeric.take(positions, allow_fill=False, fill_value=True),
            CategoricalIndex([2, 1, 3], name="xxx"),
        )

        # ordered object categories
        def _abc(values):
            return CategoricalIndex(
                values, categories=list("ABC"), ordered=True, name="xxx"
            )

        letters = _abc(list("CBA"))
        self._assert_taken(letters.take(positions), _abc(list("BCA")))
        self._assert_taken(
            letters.take(positions, fill_value=True), _abc(["B", "C", np.nan])
        )
        self._assert_taken(
            letters.take(positions, allow_fill=False, fill_value=True),
            _abc(list("BCA")),
        )

        self._assert_invalid_positions_raise(letters)

    def test_take_fill_value_datetime(self):
        positions = np.array([1, 0, -1])

        # datetime categories
        dates = CategoricalIndex(
            pd.DatetimeIndex(["2011-01-01", "2011-02-01", "2011-03-01"], name="xxx")
        )
        reordered = CategoricalIndex(
            pd.DatetimeIndex(["2011-02-01", "2011-01-01", "2011-03-01"], name="xxx")
        )
        self._assert_taken(dates.take(positions), reordered, check_values=False)

        # a fill value turns the -1 slot into NaT but keeps all categories
        with_nat = CategoricalIndex(
            pd.DatetimeIndex(["2011-02-01", "2011-01-01", "NaT"], name="xxx"),
            categories=pd.DatetimeIndex(["2011-01-01", "2011-02-01", "2011-03-01"]),
        )
        self._assert_taken(
            dates.take(positions, fill_value=True), with_nat, check_values=False
        )

        self._assert_taken(
            dates.take(positions, allow_fill=False, fill_value=True),
            reordered,
            check_values=False,
        )

        self._assert_invalid_positions_raise(dates)

    def test_take_invalid_kwargs(self):
        idx = CategoricalIndex([1, 2, 3], name="foo")
        indices = [1, 0, -1]

        # (expected exception, message pattern, offending keyword arguments)
        bad_calls = [
            (
                TypeError,
                r"take\(\) got an unexpected keyword argument 'foo'",
                {"foo": 2},
            ),
            (ValueError, "the 'out' parameter is not supported", {"out": indices}),
            (ValueError, "the 'mode' parameter is not supported", {"mode": "clip"}),
        ]
        for exc_type, pattern, kwargs in bad_calls:
            with pytest.raises(exc_type, match=pattern):
                idx.take(indices, **kwargs)
127
128
class TestGetLoc:
    """``CategoricalIndex.get_loc`` must agree with an equivalent object ``Index``."""

    @staticmethod
    def _assert_missing(indexes, key, pattern):
        """Each index in ``indexes`` must raise KeyError matching ``pattern``."""
        for index in indexes:
            with pytest.raises(KeyError, match=pattern):
                index.get_loc(key)

    def test_get_loc(self):
        # GH 12531
        # unique labels
        uniq_cat = CategoricalIndex(list("abcde"), categories=list("edabc"))
        uniq_obj = Index(list("abcde"))
        for label in ("a", "e"):
            assert uniq_cat.get_loc(label) == uniq_obj.get_loc(label)
        self._assert_missing([uniq_cat, uniq_obj], "NOT-EXIST", "'NOT-EXIST'")

        # non-unique, non-monotonic: a duplicated label yields a boolean mask
        dup_cat = CategoricalIndex(list("aacded"), categories=list("edabc"))
        dup_obj = Index(list("aacded"))

        mask = dup_cat.get_loc("d")
        tm.assert_numpy_array_equal(mask, dup_obj.get_loc("d"))
        tm.assert_numpy_array_equal(
            mask, np.array([False, False, False, True, False, True])
        )
        # ... while a label occurring once still yields a scalar position
        position = dup_cat.get_loc("e")
        assert position == dup_obj.get_loc("e")
        assert position == 4
        self._assert_missing([dup_cat, dup_obj], "NOT-EXIST", "'NOT-EXIST'")

        # non-unique but sliceable: contiguous runs yield slices
        run_cat = CategoricalIndex(list("aabbb"), categories=list("abc"))
        run_obj = Index(list("aabbb"))
        for label, expected in (("a", slice(0, 2, None)), ("b", slice(2, 5, None))):
            located = run_cat.get_loc(label)
            assert located == run_obj.get_loc(label)
            assert located == expected
        # "c" is a category but never occurs, so it is still missing
        self._assert_missing([run_cat, run_obj], "c", "'c'")

    def test_get_loc_unique(self):
        assert CategoricalIndex(list("abc")).get_loc("b") == 1

    def test_get_loc_monotonic_nonunique(self):
        assert CategoricalIndex(list("abbc")).get_loc("b") == slice(1, 3, None)

    def test_get_loc_nonmonotonic_nonunique(self):
        tm.assert_numpy_array_equal(
            CategoricalIndex(list("abcb")).get_loc("b"),
            np.array([False, True, False, True], dtype=bool),
        )
193
194
195class TestGetIndexer:
196    def test_get_indexer_base(self):
197        # Determined by cat ordering.
198        idx = CategoricalIndex(list("cab"), categories=list("cab"))
199        expected = np.arange(len(idx), dtype=np.intp)
200
201        actual = idx.get_indexer(idx)
202        tm.assert_numpy_array_equal(expected, actual)
203
204        with pytest.raises(ValueError, match="Invalid fill method"):
205            idx.get_indexer(idx, method="invalid")
206
207    def test_get_indexer_non_unique(self):
208        np.random.seed(123456789)
209
210        ci = CategoricalIndex(list("aabbca"), categories=list("cab"), ordered=False)
211        oidx = Index(np.array(ci))
212
213        for n in [1, 2, 5, len(ci)]:
214            finder = oidx[np.random.randint(0, len(ci), size=n)]
215            expected = oidx.get_indexer_non_unique(finder)[0]
216
217            actual = ci.get_indexer(finder)
218            tm.assert_numpy_array_equal(expected, actual)
219
220        # see gh-17323
221        #
222        # Even when indexer is equal to the
223        # members in the index, we should
224        # respect duplicates instead of taking
225        # the fast-track path.
226        for finder in [list("aabbca"), list("aababca")]:
227            expected = oidx.get_indexer_non_unique(finder)[0]
228
229            actual = ci.get_indexer(finder)
230            tm.assert_numpy_array_equal(expected, actual)
231
232    def test_get_indexer(self):
233
234        idx1 = CategoricalIndex(list("aabcde"), categories=list("edabc"))
235        idx2 = CategoricalIndex(list("abf"))
236
237        for indexer in [idx2, list("abf"), Index(list("abf"))]:
238            r1 = idx1.get_indexer(idx2)
239            tm.assert_almost_equal(r1, np.array([0, 1, 2, -1], dtype=np.intp))
240
241        msg = "method pad not yet implemented for CategoricalIndex"
242        with pytest.raises(NotImplementedError, match=msg):
243            idx2.get_indexer(idx1, method="pad")
244        msg = "method backfill not yet implemented for CategoricalIndex"
245        with pytest.raises(NotImplementedError, match=msg):
246            idx2.get_indexer(idx1, method="backfill")
247
248        msg = "method nearest not yet implemented for CategoricalIndex"
249        with pytest.raises(NotImplementedError, match=msg):
250            idx2.get_indexer(idx1, method="nearest")
251
252    def test_get_indexer_array(self):
253        arr = np.array(
254            [Timestamp("1999-12-31 00:00:00"), Timestamp("2000-12-31 00:00:00")],
255            dtype=object,
256        )
257        cats = [Timestamp("1999-12-31 00:00:00"), Timestamp("2000-12-31 00:00:00")]
258        ci = CategoricalIndex(cats, categories=cats, ordered=False, dtype="category")
259        result = ci.get_indexer(arr)
260        expected = np.array([0, 1], dtype="intp")
261        tm.assert_numpy_array_equal(result, expected)
262
263    def test_get_indexer_same_categories_same_order(self):
264        ci = CategoricalIndex(["a", "b"], categories=["a", "b"])
265
266        result = ci.get_indexer(CategoricalIndex(["b", "b"], categories=["a", "b"]))
267        expected = np.array([1, 1], dtype="intp")
268        tm.assert_numpy_array_equal(result, expected)
269
270    def test_get_indexer_same_categories_different_order(self):
271        # https://github.com/pandas-dev/pandas/issues/19551
272        ci = CategoricalIndex(["a", "b"], categories=["a", "b"])
273
274        result = ci.get_indexer(CategoricalIndex(["b", "b"], categories=["b", "a"]))
275        expected = np.array([1, 1], dtype="intp")
276        tm.assert_numpy_array_equal(result, expected)
277
278
279class TestWhere:
280    @pytest.mark.parametrize("klass", [list, tuple, np.array, pd.Series])
281    def test_where(self, klass):
282        i = CategoricalIndex(list("aabbca"), categories=list("cab"), ordered=False)
283        cond = [True] * len(i)
284        expected = i
285        result = i.where(klass(cond))
286        tm.assert_index_equal(result, expected)
287
288        cond = [False] + [True] * (len(i) - 1)
289        expected = CategoricalIndex([np.nan] + i[1:].tolist(), categories=i.categories)
290        result = i.where(klass(cond))
291        tm.assert_index_equal(result, expected)
292
293    def test_where_non_categories(self):
294        ci = CategoricalIndex(["a", "b", "c", "d"])
295        mask = np.array([True, False, True, False])
296
297        msg = "Cannot setitem on a Categorical with a new category"
298        with pytest.raises(ValueError, match=msg):
299            ci.where(mask, 2)
300
301        with pytest.raises(ValueError, match=msg):
302            # Test the Categorical method directly
303            ci._data.where(mask, 2)
304
305
class TestContains:
    """Membership (``in``) semantics of ``CategoricalIndex``."""

    def test_contains(self):
        ci = CategoricalIndex(list("aabbca"), categories=list("cabdef"), ordered=False)

        assert "a" in ci
        # "z" is not a category, "e" is an unused category, and no NaN occurs.
        for absent in ("z", "e", np.nan):
            assert absent not in ci

        # The integer codes themselves must not count as members.
        for code in (0, 1):
            assert code not in ci

    def test_contains_nan(self):
        values = list("aabbca") + [np.nan]
        assert np.nan in CategoricalIndex(values, categories=list("cabdef"))

    @pytest.mark.parametrize("unwrap", [True, False])
    def test_contains_na_dtype(self, unwrap):
        dti = pd.date_range("2016-01-01", periods=100).insert(0, pd.NaT)
        pi = dti.to_period("D")
        tdi = dti - dti[-1]

        # (values, datetime64 NaT is a member?, timedelta64 NaT is a member?)
        cases = [(dti, True, False), (tdi, False, True), (pi, False, False)]
        for values, dt64_nat_found, td64_nat_found in cases:
            container = CategoricalIndex(values)
            if unwrap:
                container = container._data

            # Generic missing markers match every datetime-like dtype here.
            for na in (np.nan, None, pd.NaT):
                assert na in container
            # Typed NaT scalars only match their own dtype.
            assert (np.datetime64("NaT") in container) is dt64_nat_found
            assert (np.timedelta64("NaT") in container) is td64_nat_found

    @pytest.mark.parametrize(
        "item, expected",
        [
            (pd.Interval(0, 1), True),
            (1.5, True),
            (pd.Interval(0.5, 1.5), False),
            ("a", False),
            (Timestamp(1), False),
            (pd.Timedelta(1), False),
        ],
        ids=str,
    )
    def test_contains_interval(self, item, expected):
        # GH 23705
        intervals = CategoricalIndex(IntervalIndex.from_breaks(range(3)))
        assert (item in intervals) is expected

    def test_contains_list(self):
        # GH#21729
        idx = CategoricalIndex([1, 2, 3])

        assert "a" not in idx

        # List keys are unhashable and must raise rather than return False.
        for key in (["a"], ["a", "b"]):
            with pytest.raises(TypeError, match="unhashable type"):
                key in idx
390