1import numpy as np
2import pytest
3
4from pandas import Interval, Period, Timedelta, Timestamp
5import pandas._testing as tm
6import pandas.core.common as com
7
8
@pytest.fixture
def interval():
    """Provide the default unit interval ``Interval(0, 1)``, closed on the right."""
    lower, upper = 0, 1
    return Interval(lower, upper)
12
13
class TestInterval:
    """Tests for the scalar ``Interval`` type.

    Covers endpoint properties, repr/str, containment, equality and ordering,
    hashing, ``length``/``is_empty``, arithmetic with scalars and constructor
    validation. Methods taking ``interval`` rely on a fixture of that name;
    methods taking ``closed`` rely on a ``closed`` fixture that is not defined
    in this class (presumably a shared conftest fixture yielding each option).
    """

    def test_properties(self, interval):
        assert interval.closed == "right"
        assert interval.left == 0
        assert interval.right == 1
        assert interval.mid == 0.5

    def test_repr(self, interval):
        # repr shows the constructor call; str uses bracket notation where
        # "(" / ")" mark an open end and "[" / "]" a closed end
        assert repr(interval) == "Interval(0, 1, closed='right')"
        assert str(interval) == "(0, 1]"

        interval_left = Interval(0, 1, closed="left")
        assert repr(interval_left) == "Interval(0, 1, closed='left')"
        assert str(interval_left) == "[0, 1)"

    def test_contains(self, interval):
        # right-closed: the right endpoint is included, the left one is not
        assert 0.5 in interval
        assert 1 in interval
        assert 0 not in interval

        # membership of one Interval inside another is deliberately unsupported
        msg = "__contains__ not defined for two intervals"
        with pytest.raises(TypeError, match=msg):
            interval in interval

        interval_both = Interval(0, 1, closed="both")
        assert 0 in interval_both
        assert 1 in interval_both

        interval_neither = Interval(0, 1, closed="neither")
        assert 0 not in interval_neither
        assert 0.5 in interval_neither
        assert 1 not in interval_neither

    def test_equal(self):
        # equality requires matching endpoints *and* a matching ``closed``
        assert Interval(0, 1) == Interval(0, 1, closed="right")
        assert Interval(0, 1) != Interval(0, 1, closed="left")
        assert Interval(0, 1) != Interval(0, 2)
        assert Interval(0, 1) != 0

    def test_comparison(self):
        # ordering against a non-Interval operand is a TypeError
        msg = (
            "'<' not supported between instances of "
            "'pandas._libs.interval.Interval' and 'int'"
        )
        with pytest.raises(TypeError, match=msg):
            Interval(0, 1) < 2

        # Intervals order by left endpoint, ties broken by right endpoint
        assert Interval(0, 1) < Interval(1, 2)
        assert Interval(0, 1) < Interval(0, 2)
        assert Interval(0, 1) < Interval(0.5, 1.5)
        assert Interval(0, 1) <= Interval(0, 1)
        assert Interval(0, 1) > Interval(-1, 2)
        assert Interval(0, 1) >= Interval(0, 1)

    def test_hash(self, interval):
        # should not raise
        result = hash(interval)

        # equal objects must hash equally; rebuild an equal Interval from the
        # fixture's own attributes so this does not depend on its exact value
        same = Interval(interval.left, interval.right, closed=interval.closed)
        assert same == interval
        assert hash(same) == result

    @pytest.mark.parametrize(
        "left, right, expected",
        [
            (0, 5, 5),
            (-2, 5.5, 7.5),
            (10, 10, 0),
            (10, np.inf, np.inf),
            (-np.inf, -5, np.inf),
            (-np.inf, np.inf, np.inf),
            (Timedelta("0 days"), Timedelta("5 days"), Timedelta("5 days")),
            (Timedelta("10 days"), Timedelta("10 days"), Timedelta("0 days")),
            # lowercase unit aliases: uppercase "H"/"S" are deprecated
            (Timedelta("1h10min"), Timedelta("5h5min"), Timedelta("3h55min")),
            (Timedelta("5s"), Timedelta("1h"), Timedelta("59min55s")),
        ],
    )
    def test_length(self, left, right, expected):
        # GH 18789
        # ``length`` is ``right - left`` for numeric and Timedelta endpoints
        iv = Interval(left, right)
        result = iv.length
        assert result == expected

    @pytest.mark.parametrize(
        "left, right, expected",
        [
            ("2017-01-01", "2017-01-06", "5 days"),
            ("2017-01-01", "2017-01-01 12:00:00", "12 hours"),
            ("2017-01-01 12:00", "2017-01-01 12:00:00", "0 days"),
            ("2017-01-01 12:01", "2017-01-05 17:31:00", "4 days 5 hours 30 min"),
        ],
    )
    @pytest.mark.parametrize("tz", (None, "UTC", "CET", "US/Eastern"))
    def test_length_timestamp(self, tz, left, right, expected):
        # GH 18789
        # Timestamp endpoints (naive or tz-aware) yield a Timedelta length
        iv = Interval(Timestamp(left, tz=tz), Timestamp(right, tz=tz))
        result = iv.length
        expected = Timedelta(expected)
        assert result == expected

    @pytest.mark.parametrize(
        "left, right",
        [
            (0, 1),
            (Timedelta("0 days"), Timedelta("1 day")),
            (Timestamp("2018-01-01"), Timestamp("2018-01-02")),
            (
                Timestamp("2018-01-01", tz="US/Eastern"),
                Timestamp("2018-01-02", tz="US/Eastern"),
            ),
        ],
    )
    def test_is_empty(self, left, right, closed):
        # GH27219
        # a non-degenerate interval is never empty, whatever ``closed`` is
        iv = Interval(left, right, closed)
        assert iv.is_empty is False

        # same endpoint is empty except when closed='both' (contains one point)
        iv = Interval(left, left, closed)
        result = iv.is_empty
        expected = closed != "both"
        assert result is expected

    @pytest.mark.parametrize(
        "left, right",
        [
            ("a", "z"),
            (("a", "b"), ("c", "d")),
            (list("AB"), list("ab")),
            (Interval(0, 1), Interval(1, 2)),
            (Period("2018Q1", freq="Q"), Period("2018Q1", freq="Q")),
        ],
    )
    def test_construct_errors(self, left, right):
        # GH 23013
        # strings, tuples, lists, Intervals and Periods are rejected as endpoints
        msg = "Only numeric, Timestamp and Timedelta endpoints are allowed"
        with pytest.raises(ValueError, match=msg):
            Interval(left, right)

    def test_math_add(self, closed):
        interval = Interval(0, 1, closed=closed)
        expected = Interval(1, 2, closed=closed)

        result = interval + 1
        assert result == expected

        result = 1 + interval
        assert result == expected

        result = interval
        result += 1
        assert result == expected
        # the in-place form rebinds ``result``; the original is left untouched
        assert interval == Interval(0, 1, closed=closed)

        msg = r"unsupported operand type\(s\) for \+"
        with pytest.raises(TypeError, match=msg):
            interval + interval

        with pytest.raises(TypeError, match=msg):
            interval + "foo"

    def test_math_sub(self, closed):
        interval = Interval(0, 1, closed=closed)
        expected = Interval(-1, 0, closed=closed)

        result = interval - 1
        assert result == expected

        result = interval
        result -= 1
        assert result == expected
        # the in-place form rebinds ``result``; the original is left untouched
        assert interval == Interval(0, 1, closed=closed)

        msg = r"unsupported operand type\(s\) for -"
        with pytest.raises(TypeError, match=msg):
            interval - interval

        with pytest.raises(TypeError, match=msg):
            interval - "foo"

    def test_math_mult(self, closed):
        interval = Interval(0, 1, closed=closed)
        expected = Interval(0, 2, closed=closed)

        result = interval * 2
        assert result == expected

        result = 2 * interval
        assert result == expected

        result = interval
        result *= 2
        assert result == expected
        # the in-place form rebinds ``result``; the original is left untouched
        assert interval == Interval(0, 1, closed=closed)

        msg = r"unsupported operand type\(s\) for \*"
        with pytest.raises(TypeError, match=msg):
            interval * interval

        # this error comes from ``str`` refusing to repeat by a non-int
        msg = "can't multiply sequence by non-int"
        with pytest.raises(TypeError, match=msg):
            interval * "foo"

    def test_math_div(self, closed):
        interval = Interval(0, 1, closed=closed)
        expected = Interval(0, 0.5, closed=closed)

        result = interval / 2.0
        assert result == expected

        result = interval
        result /= 2.0
        assert result == expected
        # the in-place form rebinds ``result``; the original is left untouched
        assert interval == Interval(0, 1, closed=closed)

        msg = r"unsupported operand type\(s\) for /"
        with pytest.raises(TypeError, match=msg):
            interval / interval

        with pytest.raises(TypeError, match=msg):
            interval / "foo"

    def test_math_floordiv(self, closed):
        interval = Interval(1, 2, closed=closed)
        expected = Interval(0, 1, closed=closed)

        result = interval // 2
        assert result == expected

        result = interval
        result //= 2
        assert result == expected
        # the in-place form rebinds ``result``; the original is left untouched
        assert interval == Interval(1, 2, closed=closed)

        msg = r"unsupported operand type\(s\) for //"
        with pytest.raises(TypeError, match=msg):
            interval // interval

        with pytest.raises(TypeError, match=msg):
            interval // "foo"

    def test_constructor_errors(self):
        msg = "invalid option for 'closed': foo"
        with pytest.raises(ValueError, match=msg):
            Interval(0, 1, closed="foo")

        msg = "left side of interval must be <= right side"
        with pytest.raises(ValueError, match=msg):
            Interval(1, 0)

    @pytest.mark.parametrize(
        "tz_left, tz_right", [(None, "UTC"), ("UTC", None), ("UTC", "US/Eastern")]
    )
    def test_constructor_errors_tz(self, tz_left, tz_right):
        # GH 18538
        left = Timestamp("2017-01-01", tz=tz_left)
        right = Timestamp("2017-01-02", tz=tz_right)

        # mixing naive and aware endpoints fails when the endpoints are
        # compared (TypeError); two distinct zones fail the tz check instead
        if com.any_none(tz_left, tz_right):
            error = TypeError
            msg = "Cannot compare tz-naive and tz-aware timestamps"
        else:
            error = ValueError
            msg = "left and right must have the same time zone"
        with pytest.raises(error, match=msg):
            Interval(left, right)

    def test_equality_comparison_broadcasts_over_array(self):
        # https://github.com/pandas-dev/pandas/issues/35931
        # ``Interval == ndarray`` compares element-wise instead of returning a
        # single bool
        interval = Interval(0, 1)
        arr = np.array([interval, interval])
        result = interval == arr
        expected = np.array([True, True])
        tm.assert_numpy_array_equal(result, expected)
279