1import numpy as np
2import pytest
3
4from pandas._libs import iNaT
5
6from pandas.core.dtypes.dtypes import PeriodDtype
7
8import pandas as pd
9from pandas.core.arrays import PeriodArray
10from pandas.tests.extension import base
11
12
@pytest.fixture
def dtype():
    """Provide the daily-frequency ``PeriodDtype`` the other fixtures build on."""
    daily = PeriodDtype(freq="D")
    return daily
16
17
@pytest.fixture
def data(dtype):
    """Build a 100-element PeriodArray from the integers 1970 through 2069."""
    values = np.arange(1970, 2070)
    return PeriodArray(values, freq=dtype.freq)
21
22
@pytest.fixture
def data_for_twos(dtype):
    """Build a 100-element PeriodArray in which every input value is 2."""
    # Same float64 array that ``np.ones(100) * 2`` produces.
    twos = np.full(100, 2.0)
    return PeriodArray(twos, freq=dtype.freq)
26
27
@pytest.fixture
def data_for_sorting(dtype):
    """Build a three-element PeriodArray whose values are not in sorted order."""
    unsorted_values = [2018, 2019, 2017]
    return PeriodArray(unsorted_values, freq=dtype.freq)
31
32
@pytest.fixture
def data_missing(dtype):
    """Build a two-element PeriodArray: ``iNaT`` first, then 2017."""
    values = [iNaT, 2017]
    return PeriodArray(values, freq=dtype.freq)
36
37
@pytest.fixture
def data_missing_for_sorting(dtype):
    """Build a three-element PeriodArray with ``iNaT`` between 2018 and 2017."""
    values = [2018, iNaT, 2017]
    return PeriodArray(values, freq=dtype.freq)
41
42
@pytest.fixture
def data_for_grouping(dtype):
    """Build an eight-element PeriodArray laid out as [B, B, NA, NA, A, A, B, C].

    Here A=2017, B=2018, C=2019 and NA is ``iNaT``.
    """
    low, mid, high = 2017, 2018, 2019
    missing = iNaT
    values = [mid, mid, missing, missing, low, low, mid, high]
    return PeriodArray(values, freq=dtype.freq)
50
51
@pytest.fixture
def na_value():
    """Provide ``pd.NaT`` as the missing-value scalar for these tests."""
    missing = pd.NaT
    return missing
55
56
class BasePeriodTests:
    """Empty common base class shared by the Period test classes in this module."""
59
60
class TestPeriodDtype(BasePeriodTests, base.BaseDtypeTests):
    """Inherit every test from ``base.BaseDtypeTests`` without overrides."""
63
64
class TestConstructors(BasePeriodTests, base.BaseConstructorsTests):
    """Inherit every test from ``base.BaseConstructorsTests`` without overrides."""
67
68
class TestGetitem(BasePeriodTests, base.BaseGetitemTests):
    """Inherit every test from ``base.BaseGetitemTests`` without overrides."""
71
72
class TestMethods(BasePeriodTests, base.BaseMethodsTests):
    """Inherit ``base.BaseMethodsTests``, neutralising the one test that adds."""

    def test_combine_add(self, data_repeated):
        """Do nothing: Period + Period is not defined."""
77
78
class TestInterface(BasePeriodTests, base.BaseInterfaceTests):
    """Inherit every test from ``base.BaseInterfaceTests`` without overrides."""
82
83
class TestArithmeticOps(BasePeriodTests, base.BaseArithmeticOpsTests):
    """Arithmetic-operator tests adapted for PeriodArray.

    Operators listed in ``implements`` are checked with ``exc=None``; every
    other operator is delegated to the matching base-class test.
    """

    # Operator names that the overrides below check with ``exc=None``.
    implements = {"__sub__", "__rsub__"}

    def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
        """Check ``DataFrame <op> scalar`` using the first element of ``data``."""
        # frame & scalar
        if all_arithmetic_operators in self.implements:
            df = pd.DataFrame({"A": data})
            self.check_opname(df, all_arithmetic_operators, data[0], exc=None)
        else:
            # ... but not the rest.
            super().test_arith_frame_with_scalar(data, all_arithmetic_operators)

    def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
        """Check ``Series <op> scalar`` using the first element of the Series."""
        # we implement subtraction...
        if all_arithmetic_operators in self.implements:
            s = pd.Series(data)
            self.check_opname(s, all_arithmetic_operators, s.iloc[0], exc=None)
        else:
            # ... but not the rest.
            super().test_arith_series_with_scalar(data, all_arithmetic_operators)

    def test_arith_series_with_array(self, data, all_arithmetic_operators):
        """Check the array variant of the Series arithmetic test.

        Operators outside ``implements`` are delegated to the base-class
        *array* test (previously this fell through to the scalar test, so the
        base array test was never run for them).
        """
        if all_arithmetic_operators in self.implements:
            s = pd.Series(data)
            # NOTE(review): the operand here is still a scalar, exactly as in
            # test_arith_series_with_scalar; consider an array-like operand.
            self.check_opname(s, all_arithmetic_operators, s.iloc[0], exc=None)
        else:
            # ... but not the rest.
            super().test_arith_series_with_array(data, all_arithmetic_operators)

    def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
        """Forward to the base helper, always expecting ``TypeError``.

        The ``exc`` argument is accepted for signature compatibility but is
        deliberately ignored.
        """
        super()._check_divmod_op(s, op, other, exc=TypeError)

    def test_add_series_with_extension_array(self, data):
        """Assert that ``Series + PeriodArray`` raises ``TypeError``."""
        # we don't implement + for Period
        s = pd.Series(data)
        msg = (
            r"unsupported operand type\(s\) for \+: "
            r"\'PeriodArray\' and \'PeriodArray\'"
        )
        with pytest.raises(TypeError, match=msg):
            s + data

    def test_error(self):
        """Intentional no-op: this method performs no checks."""
        # NOTE(review): presumably shadows a base-class ``test_error`` so that
        # it does not run for Period data -- confirm against the base class.
        pass

    @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
    def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
        """Assert ``data.__sub__(Series | DataFrame)`` returns ``NotImplemented``."""
        # Override to use __sub__ instead of __add__
        other = pd.Series(data)
        if box is pd.DataFrame:
            other = other.to_frame()

        result = data.__sub__(other)
        assert result is NotImplemented
138
139
class TestCasting(BasePeriodTests, base.BaseCastingTests):
    """Inherit every test from ``base.BaseCastingTests`` without overrides."""
142
143
class TestComparisonOps(BasePeriodTests, base.BaseComparisonOpsTests):
    """Inherit ``base.BaseComparisonOpsTests`` with the comparison hook disabled."""

    def _compare_other(self, s, data, op_name, other):
        """Do nothing.

        The base test is not appropriate for us. We raise on comparison
        with (some) integers, depending on the value.
        """
149
150
class TestMissing(BasePeriodTests, base.BaseMissingTests):
    """Inherit every test from ``base.BaseMissingTests`` without overrides."""
153
154
class TestReshaping(BasePeriodTests, base.BaseReshapingTests):
    """Inherit every test from ``base.BaseReshapingTests`` without overrides."""
157
158
class TestSetitem(BasePeriodTests, base.BaseSetitemTests):
    """Inherit every test from ``base.BaseSetitemTests`` without overrides."""
161
162
class TestGroupby(BasePeriodTests, base.BaseGroupbyTests):
    """Inherit every test from ``base.BaseGroupbyTests`` without overrides."""
165
166
class TestPrinting(BasePeriodTests, base.BasePrintingTests):
    """Inherit every test from ``base.BasePrintingTests`` without overrides."""
169
170
class TestParsing(BasePeriodTests, base.BaseParsingTests):
    """Run the inherited parsing test once per parser engine."""

    @pytest.mark.parametrize("engine", ["c", "python"])
    def test_EA_types(self, engine, data):
        """Delegate to the base-class ``test_EA_types`` for the given engine."""
        inherited_test = super().test_EA_types
        inherited_test(engine, data)
175