1from itertools import permutations
2
3import numpy as np
4import pytest
5
6from pandas._libs.interval import IntervalTree
7from pandas.compat import IS64
8
9import pandas._testing as tm
10
11
def skipif_32bit(param):
    """
    Wrap a parametrize value so that it is skipped on 32bit systems.

    The value is returned as a ``pytest.param`` carrying a ``skipif`` mark
    that fires when ``IS64`` is false. Used here for the leaf_size
    parameters affected by GH 23440.
    """
    return pytest.param(
        param,
        marks=pytest.mark.skipif(
            not IS64, reason="GH 23440: int type mismatch on 32bit"
        ),
    )
19
20
21@pytest.fixture(scope="class", params=["int64", "float64", "uint64"])
def dtype(request):
    """
    Fixture yielding the dtype name of the current parametrization.
    """
    selected = request.param
    return selected
24
25
26@pytest.fixture(params=[skipif_32bit(1), skipif_32bit(2), 10])
def leaf_size(request):
    """
    Fixture supplying the ``leaf_size`` argument for IntervalTree.

    Intended to be consumed by the ``tree`` fixture; the value is whatever
    the current parametrization provides.
    """
    size = request.param
    return size
33
34
35@pytest.fixture(
36    params=[
37        np.arange(5, dtype="int64"),
38        np.arange(5, dtype="uint64"),
39        np.arange(5, dtype="float64"),
40        np.array([0, 1, 2, 3, 4, np.nan], dtype="float64"),
41    ]
42)
def tree(request, leaf_size):
    """
    Fixture building an IntervalTree from the parametrized left endpoints.

    Each right endpoint is the matching left endpoint plus two, and the
    tree is constructed with the requested ``leaf_size``.
    """
    left_bounds = request.param
    right_bounds = left_bounds + 2
    return IntervalTree(left_bounds, right_bounds, leaf_size=leaf_size)
46
47
48class TestIntervalTree:
49    def test_get_indexer(self, tree):
50        result = tree.get_indexer(np.array([1.0, 5.5, 6.5]))
51        expected = np.array([0, 4, -1], dtype="intp")
52        tm.assert_numpy_array_equal(result, expected)
53
54        with pytest.raises(
55            KeyError, match="'indexer does not intersect a unique set of intervals'"
56        ):
57            tree.get_indexer(np.array([3.0]))
58
59    @pytest.mark.parametrize(
60        "dtype, target_value, target_dtype",
61        [("int64", 2 ** 63 + 1, "uint64"), ("uint64", -1, "int64")],
62    )
63    def test_get_indexer_overflow(self, dtype, target_value, target_dtype):
64        left, right = np.array([0, 1], dtype=dtype), np.array([1, 2], dtype=dtype)
65        tree = IntervalTree(left, right)
66
67        result = tree.get_indexer(np.array([target_value], dtype=target_dtype))
68        expected = np.array([-1], dtype="intp")
69        tm.assert_numpy_array_equal(result, expected)
70
71    def test_get_indexer_non_unique(self, tree):
72        indexer, missing = tree.get_indexer_non_unique(np.array([1.0, 2.0, 6.5]))
73
74        result = indexer[:1]
75        expected = np.array([0], dtype="intp")
76        tm.assert_numpy_array_equal(result, expected)
77
78        result = np.sort(indexer[1:3])
79        expected = np.array([0, 1], dtype="intp")
80        tm.assert_numpy_array_equal(result, expected)
81
82        result = np.sort(indexer[3:])
83        expected = np.array([-1], dtype="intp")
84        tm.assert_numpy_array_equal(result, expected)
85
86        result = missing
87        expected = np.array([2], dtype="intp")
88        tm.assert_numpy_array_equal(result, expected)
89
90    @pytest.mark.parametrize(
91        "dtype, target_value, target_dtype",
92        [("int64", 2 ** 63 + 1, "uint64"), ("uint64", -1, "int64")],
93    )
94    def test_get_indexer_non_unique_overflow(self, dtype, target_value, target_dtype):
95        left, right = np.array([0, 2], dtype=dtype), np.array([1, 3], dtype=dtype)
96        tree = IntervalTree(left, right)
97        target = np.array([target_value], dtype=target_dtype)
98
99        result_indexer, result_missing = tree.get_indexer_non_unique(target)
100        expected_indexer = np.array([-1], dtype="intp")
101        tm.assert_numpy_array_equal(result_indexer, expected_indexer)
102
103        expected_missing = np.array([0], dtype="intp")
104        tm.assert_numpy_array_equal(result_missing, expected_missing)
105
106    def test_duplicates(self, dtype):
107        left = np.array([0, 0, 0], dtype=dtype)
108        tree = IntervalTree(left, left + 1)
109
110        with pytest.raises(
111            KeyError, match="'indexer does not intersect a unique set of intervals'"
112        ):
113            tree.get_indexer(np.array([0.5]))
114
115        indexer, missing = tree.get_indexer_non_unique(np.array([0.5]))
116        result = np.sort(indexer)
117        expected = np.array([0, 1, 2], dtype="intp")
118        tm.assert_numpy_array_equal(result, expected)
119
120        result = missing
121        expected = np.array([], dtype="intp")
122        tm.assert_numpy_array_equal(result, expected)
123
124    @pytest.mark.parametrize(
125        "leaf_size", [skipif_32bit(1), skipif_32bit(10), skipif_32bit(100), 10000]
126    )
127    def test_get_indexer_closed(self, closed, leaf_size):
128        x = np.arange(1000, dtype="float64")
129        found = x.astype("intp")
130        not_found = (-1 * np.ones(1000)).astype("intp")
131
132        tree = IntervalTree(x, x + 0.5, closed=closed, leaf_size=leaf_size)
133        tm.assert_numpy_array_equal(found, tree.get_indexer(x + 0.25))
134
135        expected = found if tree.closed_left else not_found
136        tm.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.0))
137
138        expected = found if tree.closed_right else not_found
139        tm.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.5))
140
141    @pytest.mark.parametrize(
142        "left, right, expected",
143        [
144            (np.array([0, 1, 4], dtype="int64"), np.array([2, 3, 5]), True),
145            (np.array([0, 1, 2], dtype="int64"), np.array([5, 4, 3]), True),
146            (np.array([0, 1, np.nan]), np.array([5, 4, np.nan]), True),
147            (np.array([0, 2, 4], dtype="int64"), np.array([1, 3, 5]), False),
148            (np.array([0, 2, np.nan]), np.array([1, 3, np.nan]), False),
149        ],
150    )
151    @pytest.mark.parametrize("order", (list(x) for x in permutations(range(3))))
152    def test_is_overlapping(self, closed, order, left, right, expected):
153        # GH 23309
154        tree = IntervalTree(left[order], right[order], closed=closed)
155        result = tree.is_overlapping
156        assert result is expected
157
158    @pytest.mark.parametrize("order", (list(x) for x in permutations(range(3))))
159    def test_is_overlapping_endpoints(self, closed, order):
160        """shared endpoints are marked as overlapping"""
161        # GH 23309
162        left, right = np.arange(3, dtype="int64"), np.arange(1, 4)
163        tree = IntervalTree(left[order], right[order], closed=closed)
164        result = tree.is_overlapping
165        expected = closed == "both"
166        assert result is expected
167
168    @pytest.mark.parametrize(
169        "left, right",
170        [
171            (np.array([], dtype="int64"), np.array([], dtype="int64")),
172            (np.array([0], dtype="int64"), np.array([1], dtype="int64")),
173            (np.array([np.nan]), np.array([np.nan])),
174            (np.array([np.nan] * 3), np.array([np.nan] * 3)),
175        ],
176    )
177    def test_is_overlapping_trivial(self, closed, left, right):
178        # GH 23309
179        tree = IntervalTree(left, right, closed=closed)
180        assert tree.is_overlapping is False
181
182    @pytest.mark.skipif(not IS64, reason="GH 23440")
183    def test_construction_overflow(self):
184        # GH 25485
185        left, right = np.arange(101, dtype="int64"), [np.iinfo(np.int64).max] * 101
186        tree = IntervalTree(left, right)
187
188        # pivot should be average of left/right medians
189        result = tree.root.pivot
190        expected = (50 + np.iinfo(np.int64).max) / 2
191        assert result == expected
192