import decimal
import math
import operator

import numpy as np
import pytest

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension import base

from .array import DecimalArray, DecimalDtype, make_data, to_decimal


@pytest.fixture
def dtype():
    return DecimalDtype()


@pytest.fixture
def data():
    return DecimalArray(make_data())


@pytest.fixture
def data_for_twos():
    return DecimalArray([decimal.Decimal(2) for _ in range(100)])


@pytest.fixture
def data_missing():
    return DecimalArray([decimal.Decimal("NaN"), decimal.Decimal(1)])


@pytest.fixture
def data_for_sorting():
    return DecimalArray(
        [decimal.Decimal("1"), decimal.Decimal("2"), decimal.Decimal("0")]
    )


@pytest.fixture
def data_missing_for_sorting():
    return DecimalArray(
        [decimal.Decimal("1"), decimal.Decimal("NaN"), decimal.Decimal("0")]
    )


@pytest.fixture
def na_cmp():
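    # Decimal("NaN") does not compare equal to itself, so treat two NaNs as
    # equivalent missing values by checking is_nan() on both operands.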
    return lambda x, y: x.is_nan() and y.is_nan()


@pytest.fixture
def na_value():
    return decimal.Decimal("NaN")


@pytest.fixture
def data_for_grouping():
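    # layout expected by the base grouping tests: [B, B, NA, NA, A, A, B, C],
    # where a < b < c and NA is treated as missing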
    b = decimal.Decimal("1.0")
    a = decimal.Decimal("0.0")
    c = decimal.Decimal("2.0")
    na = decimal.Decimal("NaN")
    return DecimalArray([b, b, na, na, a, a, b, c])


class BaseDecimal:
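    """Assertion helpers that treat Decimal("NaN") as a missing value."""
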
    @classmethod
    def assert_series_equal(cls, left, right, *args, **kwargs):
        def convert(x):
            # need to convert array([Decimal(NaN)], dtype='object') to np.NaN
            # because Series[object].isna() doesn't recognize decimal(NaN) as
            # NA.
            try:
                return math.isnan(x)
            except TypeError:
                return False

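        # compare the NA masks first, then compare only the non-NA values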
        if left.dtype == "object":
            left_na = left.apply(convert)
        else:
            left_na = left.isna()
        if right.dtype == "object":
            right_na = right.apply(convert)
        else:
            right_na = right.isna()

        tm.assert_series_equal(left_na, right_na)
        return tm.assert_series_equal(left[~left_na], right[~right_na], *args, **kwargs)

    @classmethod
    def assert_frame_equal(cls, left, right, *args, **kwargs):
        # TODO(EA): select_dtypes
        tm.assert_index_equal(
            left.columns,
            right.columns,
            exact=kwargs.get("check_column_type", "equiv"),
            check_names=kwargs.get("check_names", True),
            check_exact=kwargs.get("check_exact", False),
            check_categorical=kwargs.get("check_categorical", True),
            obj=f"{kwargs.get('obj', 'DataFrame')}.columns",
        )

        decimals = (left.dtypes == "decimal").index

        for col in decimals:
            cls.assert_series_equal(left[col], right[col], *args, **kwargs)

        left = left.drop(columns=decimals)
        right = right.drop(columns=decimals)
        tm.assert_frame_equal(left, right, *args, **kwargs)


class TestDtype(BaseDecimal, base.BaseDtypeTests):
    def test_hashable(self, dtype):
        pass


class TestInterface(BaseDecimal, base.BaseInterfaceTests):
    pass


class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
    @pytest.mark.skip(reason="not implemented constructor from dtype")
    def test_from_dtype(self, data):
        # construct from our dtype & string dtype
        pass


class TestReshaping(BaseDecimal, base.BaseReshapingTests):
    pass


class TestGetitem(BaseDecimal, base.BaseGetitemTests):
    def test_take_na_value_other_decimal(self):
        arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
        result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0"))
        expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")])
        self.assert_extension_array_equal(result, expected)


class TestMissing(BaseDecimal, base.BaseMissingTests):
    pass


class Reduce:
    def check_reduce(self, s, op_name, skipna):

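        # median/skew/kurt are not supported by DecimalArray and should raise;
        # other reductions are checked against the same op on the object ndarray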
        if op_name in ["median", "skew", "kurt"]:
            msg = r"decimal does not support the .* operation"
            with pytest.raises(NotImplementedError, match=msg):
                getattr(s, op_name)(skipna=skipna)

        else:
            result = getattr(s, op_name)(skipna=skipna)
            expected = getattr(np.asarray(s), op_name)()
            tm.assert_almost_equal(result, expected)


class TestNumericReduce(Reduce, base.BaseNumericReduceTests):
    pass


class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
    pass


class TestMethods(BaseDecimal, base.BaseMethodsTests):
    @pytest.mark.parametrize("dropna", [True, False])
    @pytest.mark.xfail(reason="value_counts not implemented yet.")
    def test_value_counts(self, all_data, dropna):
        all_data = all_data[:10]
        if dropna:
            other = np.array(all_data[~all_data.isna()])
        else:
            other = all_data

        result = pd.Series(all_data).value_counts(dropna=dropna).sort_index()
        expected = pd.Series(other).value_counts(dropna=dropna).sort_index()

        tm.assert_series_equal(result, expected)

    @pytest.mark.xfail(reason="value_counts not implemented yet.")
    def test_value_counts_with_normalize(self, data):
        return super().test_value_counts_with_normalize(data)


class TestCasting(BaseDecimal, base.BaseCastingTests):
    pass


class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
    @pytest.mark.xfail(
        reason="needs to correctly define __eq__ to handle nans, xref #27081."
    )
    def test_groupby_apply_identity(self, data_for_grouping):
        super().test_groupby_apply_identity(data_for_grouping)

    @pytest.mark.xfail(reason="GH#39098: Converts agg result to object")
    def test_groupby_agg_extension(self, data_for_grouping):
        super().test_groupby_agg_extension(data_for_grouping)


class TestSetitem(BaseDecimal, base.BaseSetitemTests):
    pass


class TestPrinting(BaseDecimal, base.BasePrintingTests):
    def test_series_repr(self, data):
        # Overriding this base test to explicitly test that
        # the custom _formatter is used
        ser = pd.Series(data)
        assert data.dtype.name in repr(ser)
        assert "Decimal: " in repr(ser)


# TODO(extension)
@pytest.mark.xfail(
    reason=(
        "raising AssertionError as this is not implemented, though easy enough to do"
    )
)
def test_series_constructor_coerce_data_to_extension_dtype_raises():
    xpr = (
        "Cannot cast data to extension dtype 'decimal'. Pass the "
        "extension array directly."
    )
    with pytest.raises(ValueError, match=xpr):
        pd.Series([0, 1, 2], dtype=DecimalDtype())


def test_series_constructor_with_dtype():
    arr = DecimalArray([decimal.Decimal("10.0")])
    result = pd.Series(arr, dtype=DecimalDtype())
    expected = pd.Series(arr)
    tm.assert_series_equal(result, expected)

    result = pd.Series(arr, dtype="int64")
    expected = pd.Series([10])
    tm.assert_series_equal(result, expected)


def test_dataframe_constructor_with_dtype():
    arr = DecimalArray([decimal.Decimal("10.0")])

    result = pd.DataFrame({"A": arr}, dtype=DecimalDtype())
    expected = pd.DataFrame({"A": arr})
    tm.assert_frame_equal(result, expected)

    arr = DecimalArray([decimal.Decimal("10.0")])
    result = pd.DataFrame({"A": arr}, dtype="int64")
    expected = pd.DataFrame({"A": [10]})
    tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("frame", [True, False])
def test_astype_dispatches(frame):
    # This is a dtype-specific test that ensures Series[decimal].astype
    # gets all the way through to ExtensionArray.astype
    # Designing a reliable smoke test that works for arbitrary data types
    # is difficult.
    data = pd.Series(DecimalArray([decimal.Decimal(2)]), name="a")
    ctx = decimal.Context()
    ctx.prec = 5

    if frame:
        data = data.to_frame()

    result = data.astype(DecimalDtype(ctx))

    if frame:
        result = result["a"]

    assert result.dtype.context.prec == ctx.prec


class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests):
    def check_opname(self, s, op_name, other, exc=None):
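        # decimal arithmetic is implemented on DecimalArray, so the ops are
        # expected to succeed rather than raise (exc=None)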
        super().check_opname(s, op_name, other, exc=None)

    def test_arith_series_with_array(self, data, all_arithmetic_operators):
        op_name = all_arithmetic_operators
        s = pd.Series(data)

        context = decimal.getcontext()
        divbyzerotrap = context.traps[decimal.DivisionByZero]
        invalidoptrap = context.traps[decimal.InvalidOperation]
        context.traps[decimal.DivisionByZero] = 0
        context.traps[decimal.InvalidOperation] = 0
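        # with the traps disabled, x / 0 yields Decimal Infinity and 0 / 0
        # yields Decimal NaN instead of raising, so every parametrized op runs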

        # Decimal supports ops with int, but not float
        other = pd.Series([int(d * 100) for d in data])
        self.check_opname(s, op_name, other)

        if "mod" not in op_name:
            self.check_opname(s, op_name, s * 2)

        self.check_opname(s, op_name, 0)
        self.check_opname(s, op_name, 5)
        context.traps[decimal.DivisionByZero] = divbyzerotrap
        context.traps[decimal.InvalidOperation] = invalidoptrap

    def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
        # We implement divmod
        super()._check_divmod_op(s, op, other, exc=None)

    def test_error(self):
        pass


class TestComparisonOps(BaseDecimal, base.BaseComparisonOpsTests):
    def check_opname(self, s, op_name, other, exc=None):
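        # comparisons are implemented for DecimalArray, so no exception expected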
        super().check_opname(s, op_name, other, exc=None)

    def _compare_other(self, s, data, op_name, other):
        self.check_opname(s, op_name, other)

    def test_compare_scalar(self, data, all_compare_operators):
        op_name = all_compare_operators
        s = pd.Series(data)
        self._compare_other(s, data, op_name, 0.5)

    def test_compare_array(self, data, all_compare_operators):
        op_name = all_compare_operators
        s = pd.Series(data)

        alter = np.random.choice([-1, 0, 1], len(data))
        # Randomly double, halve or keep same value
        other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter]
        self._compare_other(s, data, op_name, other)


class DecimalArrayWithoutFromSequence(DecimalArray):
    """Helper class for testing error handling in _from_sequence."""

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        raise KeyError("For the test")


class DecimalArrayWithoutCoercion(DecimalArrayWithoutFromSequence):
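    """Variant whose arithmetic results are not coerced back to DecimalArray."""
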
    @classmethod
    def _create_arithmetic_method(cls, op):
        return cls._create_method(op, coerce_to_dtype=False)


DecimalArrayWithoutCoercion._add_arithmetic_ops()


def test_combine_from_sequence_raises():
    # https://github.com/pandas-dev/pandas/issues/22850
    ser = pd.Series(
        DecimalArrayWithoutFromSequence(
            [decimal.Decimal("1.0"), decimal.Decimal("2.0")]
        )
    )
    result = ser.combine(ser, operator.add)

    # note: object dtype
    expected = pd.Series(
        [decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object"
    )
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "class_", [DecimalArrayWithoutFromSequence, DecimalArrayWithoutCoercion]
)
def test_scalar_ops_from_sequence_raises(class_):
    # op(EA, EA) should return an EA, or an ndarray if it's not possible
    # to return an EA with the return values.
    arr = class_([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
    result = arr + arr
    expected = np.array(
        [decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object"
    )
    tm.assert_numpy_array_equal(result, expected)


@pytest.mark.parametrize(
    "reverse, expected_div, expected_mod",
    [(False, [0, 1, 1, 2], [1, 0, 1, 0]), (True, [2, 1, 0, 0], [0, 0, 2, 2])],
)
def test_divmod_array(reverse, expected_div, expected_mod):
    # https://github.com/pandas-dev/pandas/issues/22930
    arr = to_decimal([1, 2, 3, 4])
    if reverse:
        div, mod = divmod(2, arr)
    else:
        div, mod = divmod(arr, 2)
    expected_div = to_decimal(expected_div)
    expected_mod = to_decimal(expected_mod)

    tm.assert_extension_array_equal(div, expected_div)
    tm.assert_extension_array_equal(mod, expected_mod)


def test_ufunc_fallback(data):
    a = data[:5]
    s = pd.Series(a, index=range(3, 8))
    result = np.abs(s)
    expected = pd.Series(np.abs(a), index=range(3, 8))
    tm.assert_series_equal(result, expected)


def test_array_ufunc():
    a = to_decimal([1, 2, 3])
    result = np.exp(a)
    expected = to_decimal(np.exp(a._data))
    tm.assert_extension_array_equal(result, expected)


def test_array_ufunc_series():
    a = to_decimal([1, 2, 3])
    s = pd.Series(a)
    result = np.exp(s)
    expected = pd.Series(to_decimal(np.exp(a._data)))
    tm.assert_series_equal(result, expected)


def test_array_ufunc_series_scalar_other():
    # check _HANDLED_TYPES
    a = to_decimal([1, 2, 3])
    s = pd.Series(a)
    result = np.add(s, decimal.Decimal(1))
    expected = pd.Series(np.add(a, decimal.Decimal(1)))
    tm.assert_series_equal(result, expected)


def test_array_ufunc_series_defer():
    a = to_decimal([1, 2, 3])
    s = pd.Series(a)

    expected = pd.Series(to_decimal([2, 4, 6]))
    r1 = np.add(s, a)
    r2 = np.add(a, s)

    tm.assert_series_equal(r1, expected)
    tm.assert_series_equal(r2, expected)


def test_groupby_agg():
    # Ensure that the result of agg is inferred to be decimal dtype
    # https://github.com/pandas-dev/pandas/issues/29141

    data = make_data()[:5]
    df = pd.DataFrame(
        {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
    )

    # single key, selected column
    expected = pd.Series(to_decimal([data[0], data[3]]))
    result = df.groupby("id1")["decimals"].agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)
    result = df["decimals"].groupby(df["id1"]).agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)

    # multiple keys, selected column
    expected = pd.Series(
        to_decimal([data[0], data[1], data[3]]),
        index=pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 1)]),
    )
    result = df.groupby(["id1", "id2"])["decimals"].agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)
    result = df["decimals"].groupby([df["id1"], df["id2"]]).agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)

    # multiple columns
    expected = pd.DataFrame({"id2": [0, 1], "decimals": to_decimal([data[0], data[3]])})
    result = df.groupby("id1").agg(lambda x: x.iloc[0])
    tm.assert_frame_equal(result, expected, check_names=False)


def test_groupby_agg_ea_method(monkeypatch):
    # Ensure that the result of agg is inferred to be decimal dtype
    # https://github.com/pandas-dev/pandas/issues/29141

    def DecimalArray__my_sum(self):
        return np.sum(np.array(self))

    monkeypatch.setattr(DecimalArray, "my_sum", DecimalArray__my_sum, raising=False)

    data = make_data()[:5]
    df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)})
    expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]]))

    result = df.groupby("id")["decimals"].agg(lambda x: x.values.my_sum())
    tm.assert_series_equal(result, expected, check_names=False)
    s = pd.Series(DecimalArray(data))
    result = s.groupby(np.array([0, 0, 0, 1, 1])).agg(lambda x: x.values.my_sum())
    tm.assert_series_equal(result, expected, check_names=False)


def test_indexing_no_materialize(monkeypatch):
    # See https://github.com/pandas-dev/pandas/issues/29708
    # Ensure that indexing operations do not materialize (convert to a numpy
    # array) the ExtensionArray unnecessarily

    def DecimalArray__array__(self, dtype=None):
        raise Exception("tried to convert a DecimalArray to a numpy array")

    monkeypatch.setattr(DecimalArray, "__array__", DecimalArray__array__, raising=False)

    data = make_data()
    s = pd.Series(DecimalArray(data))
    df = pd.DataFrame({"a": s, "b": range(len(s))})

    # ensure the following operations do not raise an error
    s[s > 0.5]
    df[s > 0.5]
    s.at[0]
    df.at[0, "a"]


def test_to_numpy_keyword():
    # DecimalArray.to_numpy accepts an extra "decimals" keyword that rounds
    # the values; it should be forwarded from both the array and the Series
    values = [decimal.Decimal("1.1111"), decimal.Decimal("2.2222")]
    expected = np.array(
        [decimal.Decimal("1.11"), decimal.Decimal("2.22")], dtype="object"
    )
    a = pd.array(values, dtype="decimal")
    result = a.to_numpy(decimals=2)
    tm.assert_numpy_array_equal(result, expected)

    result = pd.Series(a).to_numpy(decimals=2)
    tm.assert_numpy_array_equal(result, expected)