from itertools import permutations

import numpy as np
import pytest

from pandas._libs.interval import IntervalTree
from pandas.compat import IS64

import pandas._testing as tm


def skipif_32bit(param):
    """
    Wrap ``param`` in a ``pytest.param`` that is skipped when not on 64bit.

    Used below for the leaf_size parametrizations affected by GH 23440.
    """
    skip_mark = pytest.mark.skipif(
        not IS64, reason="GH 23440: int type mismatch on 32bit"
    )
    return pytest.param(param, marks=skip_mark)


@pytest.fixture(scope="class", params=["int64", "float64", "uint64"])
def dtype(request):
    """Class-scoped fixture providing the dtype name used to build bounds."""
    return request.param


@pytest.fixture(params=[skipif_32bit(1), skipif_32bit(2), 10])
def leaf_size(request):
    """
    Fixture providing the ``leaf_size`` argument for IntervalTree; consumed
    by the ``tree`` fixture.
    """
    return request.param


@pytest.fixture(
    params=[
        np.arange(5, dtype="int64"),
        np.arange(5, dtype="uint64"),
        np.arange(5, dtype="float64"),
        np.array([0, 1, 2, 3, 4, np.nan], dtype="float64"),
    ]
)
def tree(request, leaf_size):
    """
    Fixture building an IntervalTree whose right bounds are the parametrized
    left bounds shifted by 2, using the ``leaf_size`` fixture value.
    """
    lower = request.param
    upper = lower + 2
    return IntervalTree(lower, upper, leaf_size=leaf_size)


class TestIntervalTree:
    def test_get_indexer(self, tree):
        """get_indexer returns positions (or -1) and raises on ambiguity."""
        target = np.array([1.0, 5.5, 6.5])
        result = tree.get_indexer(target)
        expected = np.array([0, 4, -1], dtype="intp")
        tm.assert_numpy_array_equal(result, expected)

        # 3.0 is expected to hit more than one interval of the fixture tree
        msg = "'indexer does not intersect a unique set of intervals'"
        with pytest.raises(KeyError, match=msg):
            tree.get_indexer(np.array([3.0]))

    @pytest.mark.parametrize(
        "dtype, target_value, target_dtype",
        [("int64", 2 ** 63 + 1, "uint64"), ("uint64", -1, "int64")],
    )
    def test_get_indexer_overflow(self, dtype, target_value, target_dtype):
        """A target not representable in the tree's dtype yields -1."""
        left = np.array([0, 1], dtype=dtype)
        right = np.array([1, 2], dtype=dtype)
        tree = IntervalTree(left, right)

        target = np.array([target_value], dtype=target_dtype)
        result = tree.get_indexer(target)
        expected = np.array([-1], dtype="intp")
        tm.assert_numpy_array_equal(result, expected)

    def test_get_indexer_non_unique(self, tree):
        """get_indexer_non_unique reports every match plus missing targets."""
        target = np.array([1.0, 2.0, 6.5])
        indexer, missing = tree.get_indexer_non_unique(target)

        # first target: a single match
        first = indexer[:1]
        tm.assert_numpy_array_equal(first, np.array([0], dtype="intp"))

        # second target: two matches, compared after sorting
        second = np.sort(indexer[1:3])
        tm.assert_numpy_array_equal(second, np.array([0, 1], dtype="intp"))

        # third target: no match, reported as -1 ...
        third = np.sort(indexer[3:])
        tm.assert_numpy_array_equal(third, np.array([-1], dtype="intp"))

        # ... and its position in the target shows up in ``missing``
        tm.assert_numpy_array_equal(missing, np.array([2], dtype="intp"))

    @pytest.mark.parametrize(
        "dtype, target_value, target_dtype",
        [("int64", 2 ** 63 + 1, "uint64"), ("uint64", -1, "int64")],
    )
    def test_get_indexer_non_unique_overflow(self, dtype, target_value, target_dtype):
        """A target not representable in the tree's dtype is reported missing."""
        left = np.array([0, 2], dtype=dtype)
        right = np.array([1, 3], dtype=dtype)
        tree = IntervalTree(left, right)
        target = np.array([target_value], dtype=target_dtype)

        result_indexer, result_missing = tree.get_indexer_non_unique(target)

        expected_indexer = np.array([-1], dtype="intp")
        tm.assert_numpy_array_equal(result_indexer, expected_indexer)

        expected_missing = np.array([0], dtype="intp")
        tm.assert_numpy_array_equal(result_missing, expected_missing)

    def test_duplicates(self, dtype):
        """Three identical intervals: get_indexer raises, non_unique lists all."""
        left = np.array([0, 0, 0], dtype=dtype)
        right = left + 1
        tree = IntervalTree(left, right)

        msg = "'indexer does not intersect a unique set of intervals'"
        with pytest.raises(KeyError, match=msg):
            tree.get_indexer(np.array([0.5]))

        indexer, missing = tree.get_indexer_non_unique(np.array([0.5]))

        sorted_indexer = np.sort(indexer)
        tm.assert_numpy_array_equal(
            sorted_indexer, np.array([0, 1, 2], dtype="intp")
        )

        tm.assert_numpy_array_equal(missing, np.array([], dtype="intp"))

    @pytest.mark.parametrize(
        "leaf_size", [skipif_32bit(1), skipif_32bit(10), skipif_32bit(100), 10000]
    )
    def test_get_indexer_closed(self, closed, leaf_size):
        """Endpoint lookups match only on the sides the tree is closed on."""
        x = np.arange(1000, dtype="float64")
        found = x.astype("intp")
        not_found = np.full(found.shape, -1, dtype="intp")

        tree = IntervalTree(x, x + 0.5, closed=closed, leaf_size=leaf_size)

        # interior points are always found
        tm.assert_numpy_array_equal(found, tree.get_indexer(x + 0.25))

        # left endpoints
        if tree.closed_left:
            expected = found
        else:
            expected = not_found
        tm.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.0))

        # right endpoints
        if tree.closed_right:
            expected = found
        else:
            expected = not_found
        tm.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.5))

    @pytest.mark.parametrize(
        "left, right, expected",
        [
            (np.array([0, 1, 4], dtype="int64"), np.array([2, 3, 5]), True),
            (np.array([0, 1, 2], dtype="int64"), np.array([5, 4, 3]), True),
            (np.array([0, 1, np.nan]), np.array([5, 4, np.nan]), True),
            (np.array([0, 2, 4], dtype="int64"), np.array([1, 3, 5]), False),
            (np.array([0, 2, np.nan]), np.array([1, 3, np.nan]), False),
        ],
    )
    @pytest.mark.parametrize("order", [list(perm) for perm in permutations(range(3))])
    def test_is_overlapping(self, closed, order, left, right, expected):
        """is_overlapping gives the same answer for every ordering of bounds."""
        # GH 23309
        shuffled_left = left[order]
        shuffled_right = right[order]
        tree = IntervalTree(shuffled_left, shuffled_right, closed=closed)
        assert tree.is_overlapping is expected

    @pytest.mark.parametrize("order", [list(perm) for perm in permutations(range(3))])
    def test_is_overlapping_endpoints(self, closed, order):
        """Intervals sharing only endpoints overlap only when closed="both"."""
        # GH 23309
        left = np.arange(3, dtype="int64")
        right = np.arange(1, 4)
        tree = IntervalTree(left[order], right[order], closed=closed)
        expected = closed == "both"
        assert tree.is_overlapping is expected

    @pytest.mark.parametrize(
        "left, right",
        [
            (np.array([], dtype="int64"), np.array([], dtype="int64")),
            (np.array([0], dtype="int64"), np.array([1], dtype="int64")),
            (np.array([np.nan]), np.array([np.nan])),
            (np.array([np.nan] * 3), np.array([np.nan] * 3)),
        ],
    )
    def test_is_overlapping_trivial(self, closed, left, right):
        """Empty, single-interval and all-NaN trees are never overlapping."""
        # GH 23309
        tree = IntervalTree(left, right, closed=closed)
        assert tree.is_overlapping is False

    @pytest.mark.skipif(not IS64, reason="GH 23440")
    def test_construction_overflow(self):
        """The root pivot is computed correctly with int64-max right bounds."""
        # GH 25485
        int64_max = np.iinfo(np.int64).max
        left = np.arange(101, dtype="int64")
        right = [int64_max] * 101
        tree = IntervalTree(left, right)

        # pivot should be average of left/right medians
        expected = (50 + int64_max) / 2
        assert tree.root.pivot == expected