1from typing import Optional, Type
2
3import pytest
4
5import pandas as pd
6import pandas._testing as tm
7from pandas.core import ops
8
9from .base import BaseExtensionTests
10
11
class BaseOpsUtil(BaseExtensionTests):
    """Helper methods shared by the operator test classes in this module."""

    def get_op_from_name(self, op_name):
        """Return the callable ``tm.get_op_from_name`` gives for ``op_name``."""
        return tm.get_op_from_name(op_name)

    def check_opname(self, s, op_name, other, exc=Exception):
        """
        Look up the operator called ``op_name`` and check it on ``s``/``other``.

        Parameters
        ----------
        s : Series or DataFrame
            Left operand.
        op_name : str
            Operator name understood by ``get_op_from_name``.
        other : object
            Right operand.
        exc : type of Exception or None, default Exception
            Forwarded to ``_check_op``: the exception the operation must
            raise, or None when the operation is expected to succeed.
        """
        op = self.get_op_from_name(op_name)

        self._check_op(s, op, other, op_name, exc)

    def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
        """
        Check ``op(s, other)`` against an element-wise expectation or an error.

        Parameters
        ----------
        s : Series or DataFrame
            Left operand. A DataFrame must have exactly one column.
        op : callable
            Binary operator, called as ``op(s, other)``.
        other : object
            Right operand.
        op_name : str
            Not used by this implementation.
        exc : type of Exception or None, default NotImplementedError
            When None, ``op(s, other)`` must equal the result of applying
            ``op`` through ``combine``. Otherwise the call must raise ``exc``.

        Raises
        ------
        NotImplementedError
            If ``exc`` is None and ``s`` is a DataFrame that does not have
            exactly one column.
        """
        if exc is None:
            result = op(s, other)
            if isinstance(s, pd.DataFrame):
                if len(s.columns) != 1:
                    raise NotImplementedError
                # Build the expectation from the single column, then re-box
                # it as a frame so it is comparable with ``result``.
                expected = s.iloc[:, 0].combine(other, op).to_frame()
                self.assert_frame_equal(result, expected)
            else:
                expected = s.combine(other, op)
                self.assert_series_equal(result, expected)
        else:
            with pytest.raises(exc):
                op(s, other)

    def _check_divmod_op(self, s, op, other, exc=Exception):
        """
        Check a divmod-style ``op`` that returns a (quotient, remainder) pair.

        Parameters
        ----------
        s, other : object
            Operands, passed to ``op`` in that order.
        op : callable
            The builtin ``divmod`` or a reflected variant. Any ``op`` other
            than ``divmod`` is treated as reflected, i.e. the expected
            results are ``other // s`` and ``other % s``.
        exc : type of Exception or None, default Exception
            When None, both returned values are compared with the results of
            ``//`` and ``%``. Otherwise ``op(s, other)`` must raise ``exc``.
        """
        # divmod has multiple return values, so check separately
        if exc is None:
            result_div, result_mod = op(s, other)
            if op is divmod:
                expected_div, expected_mod = s // other, s % other
            else:
                expected_div, expected_mod = other // s, other % s
            self.assert_series_equal(result_div, expected_div)
            self.assert_series_equal(result_mod, expected_mod)
        else:
            with pytest.raises(exc):
                # Call the op that was requested rather than always the
                # builtin divmod, so a reflected op is the one that is
                # asserted to raise.
                op(s, other)
49
50
class BaseArithmeticOpsTests(BaseOpsUtil):
    """
    Arithmetic operator tests for Series and DataFrame holding the array.

    Each ``*_exc`` class attribute is the exception that the matching group
    of tests expects the operation to raise. A subclass whose array supports
    that kind of operation overrides the attribute (``None`` makes the
    helpers compare the result instead of expecting an error).

    * series_scalar_exc = TypeError
    * frame_scalar_exc = TypeError
    * series_array_exc = TypeError
    * divmod_exc = TypeError
    """

    series_scalar_exc: Optional[Type[TypeError]] = TypeError
    frame_scalar_exc: Optional[Type[TypeError]] = TypeError
    series_array_exc: Optional[Type[TypeError]] = TypeError
    divmod_exc: Optional[Type[TypeError]] = TypeError

    def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
        # Series <op> scalar, where the scalar is the Series' first element
        ser = pd.Series(data)
        scalar = ser.iloc[0]
        self.check_opname(
            ser, all_arithmetic_operators, scalar, exc=self.series_scalar_exc
        )

    @pytest.mark.xfail(run=False, reason="_reduce needs implementation")
    def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
        # single-column DataFrame <op> scalar taken from the array itself
        frame = pd.DataFrame({"A": data})
        scalar = data[0]
        self.check_opname(
            frame, all_arithmetic_operators, scalar, exc=self.frame_scalar_exc
        )

    def test_arith_series_with_array(self, data, all_arithmetic_operators):
        # Series <op> Series, the right side repeating the first element
        ser = pd.Series(data)
        repeated = pd.Series([ser.iloc[0]] * len(ser))
        self.check_opname(
            ser, all_arithmetic_operators, repeated, exc=self.series_array_exc
        )

    def test_divmod(self, data):
        # divmod and its reflected form against the scalar 1
        ser = pd.Series(data)
        expected_exc = self.divmod_exc
        self._check_divmod_op(ser, divmod, 1, exc=expected_exc)
        self._check_divmod_op(1, ops.rdivmod, ser, exc=expected_exc)

    def test_divmod_series_array(self, data, data_for_twos):
        ser = pd.Series(data)
        self._check_divmod_op(ser, divmod, data)

        # reflected divmod: first with the raw array, then boxed in a Series
        twos_array = data_for_twos
        self._check_divmod_op(twos_array, ops.rdivmod, ser)

        twos_series = pd.Series(twos_array)
        self._check_divmod_op(twos_series, ops.rdivmod, ser)

    def test_add_series_with_extension_array(self, data):
        # Series + array must equal the array added to itself, boxed afterwards
        summed = pd.Series(data) + data
        self.assert_series_equal(summed, pd.Series(data + data))

    def test_error(self, data, all_arithmetic_operators):
        # looking the operator method up on the array is expected to fail
        with pytest.raises(AttributeError):
            getattr(data, all_arithmetic_operators)

    @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
    def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
        # EAs should return NotImplemented for ops with Series/DataFrame
        # Pandas takes care of unboxing the series and calling the EA's op.
        boxed = pd.Series(data)
        if box is pd.DataFrame:
            boxed = boxed.to_frame()

        if not hasattr(data, "__add__"):
            raise pytest.skip(f"{type(data).__name__} does not implement add")

        assert data.__add__(boxed) is NotImplemented
129
130
class BaseComparisonOpsTests(BaseOpsUtil):
    """Comparison operator tests for Series and DataFrame holding the array."""

    def _compare_other(self, s, data, op_name, other):
        """
        Check the comparison named ``op_name`` between the data and ``other``.

        ``__eq__`` must not be True for every element and ``__ne__`` must be
        True for every element. For any other operator, the array's own
        method must return ``NotImplemented`` and the Series-level operation
        must raise ``TypeError``.
        """
        compare = self.get_op_from_name(op_name)

        if op_name == "__eq__":
            assert not compare(s, other).all()
            return

        if op_name == "__ne__":
            assert compare(s, other).all()
            return

        # array
        array_method = getattr(data, op_name)
        assert array_method(other) is NotImplemented

        # series
        boxed = pd.Series(data)
        with pytest.raises(TypeError):
            compare(boxed, other)

    def test_compare_scalar(self, data, all_compare_operators):
        # compare against the scalar 0
        ser = pd.Series(data)
        self._compare_other(ser, data, all_compare_operators, 0)

    def test_compare_array(self, data, all_compare_operators):
        # compare against a Series repeating the array's first element
        ser = pd.Series(data)
        repeated = pd.Series([data[0]] * len(data))
        self._compare_other(ser, data, all_compare_operators, repeated)

    @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
    def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
        # EAs should return NotImplemented for ops with Series/DataFrame
        # Pandas takes care of unboxing the series and calling the EA's op.
        boxed = pd.Series(data)
        if box is pd.DataFrame:
            boxed = boxed.to_frame()

        if not hasattr(data, "__eq__"):
            raise pytest.skip(f"{type(data).__name__} does not implement __eq__")
        assert data.__eq__(boxed) is NotImplemented

        if not hasattr(data, "__ne__"):
            raise pytest.skip(f"{type(data).__name__} does not implement __ne__")
        assert data.__ne__(boxed) is NotImplemented
181
182
class BaseUnaryOpsTests(BaseOpsUtil):
    """Unary operator tests for a Series holding the array."""

    def test_invert(self, data):
        # ~Series must match inverting the array and boxing it with the
        # same name.
        series_name = "name"
        inverted = ~pd.Series(data, name=series_name)
        from_array = pd.Series(~data, name=series_name)
        self.assert_series_equal(inverted, from_array)
189