"""Run the pandas extension-array base tests against ``DecimalArray``.

``DecimalArray``/``DecimalDtype`` come from the sibling ``array`` module.
"""
import decimal
import math
import operator

import numpy as np
import pytest

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension import base

from .array import DecimalArray, DecimalDtype, make_data, to_decimal


@pytest.fixture
def dtype():
    return DecimalDtype()


@pytest.fixture
def data():
    return DecimalArray(make_data())


@pytest.fixture
def data_for_twos():
    return DecimalArray([decimal.Decimal(2) for _ in range(100)])


@pytest.fixture
def data_missing():
    return DecimalArray([decimal.Decimal("NaN"), decimal.Decimal(1)])


@pytest.fixture
def data_for_sorting():
    return DecimalArray(
        [decimal.Decimal("1"), decimal.Decimal("2"), decimal.Decimal("0")]
    )


@pytest.fixture
def data_missing_for_sorting():
    return DecimalArray(
        [decimal.Decimal("1"), decimal.Decimal("NaN"), decimal.Decimal("0")]
    )


@pytest.fixture
def na_cmp():
    return lambda x, y: x.is_nan() and y.is_nan()


@pytest.fixture
def na_value():
    return decimal.Decimal("NaN")


@pytest.fixture
def data_for_grouping():
    b = decimal.Decimal("1.0")
    a = decimal.Decimal("0.0")
    c = decimal.Decimal("2.0")
    na = decimal.Decimal("NaN")
    return DecimalArray([b, b, na, na, a, a, b, c])


class BaseDecimal:
    """Mixin overriding the base-test assertions so Decimal NaNs count as NA."""

    @classmethod
    def assert_series_equal(cls, left, right, *args, **kwargs):
        """Compare two Series, treating Decimal NaN in object Series as NA.

        The NA masks of both sides must match exactly; the remaining
        (non-NA) values are then compared with ``tm.assert_series_equal``,
        forwarding ``*args``/``**kwargs``.
        """

        def convert(x):
            # need to convert array([Decimal(NaN)], dtype='object') to np.NaN
            # because Series[object].isnan doesn't recognize decimal(NaN) as
            # NA.
            try:
                return math.isnan(x)
            except TypeError:
                return False

        if left.dtype == "object":
            left_na = left.apply(convert)
        else:
            left_na = left.isna()
        if right.dtype == "object":
            right_na = right.apply(convert)
        else:
            right_na = right.isna()

        tm.assert_series_equal(left_na, right_na)
        return tm.assert_series_equal(left[~left_na], right[~right_na], *args, **kwargs)

    @classmethod
    def assert_frame_equal(cls, left, right, *args, **kwargs):
        """Compare two DataFrames, routing decimal/object columns specially.

        Column labels are checked first.  Columns whose dtype on ``left`` is
        decimal or object are compared one by one with the NaN-aware
        ``assert_series_equal`` above; all other columns are compared
        together with ``tm.assert_frame_equal``.
        """
        # TODO(EA): select_dtypes
        tm.assert_index_equal(
            left.columns,
            right.columns,
            exact=kwargs.get("check_column_type", "equiv"),
            check_names=kwargs.get("check_names", True),
            check_exact=kwargs.get("check_exact", False),
            check_categorical=kwargs.get("check_categorical", True),
            obj=f"{kwargs.get('obj', 'DataFrame')}.columns",
        )

        # Select only the columns that need the NaN-aware comparison.  The
        # mask itself must be used to filter: taking ``.index`` of the bare
        # boolean mask would return *every* column label.
        dtypes = left.dtypes
        needs_nan_aware = (dtypes == "decimal") | (dtypes == "object")
        decimals = dtypes[needs_nan_aware].index

        for col in decimals:
            cls.assert_series_equal(left[col], right[col], *args, **kwargs)

        left = left.drop(columns=decimals)
        right = right.drop(columns=decimals)
        tm.assert_frame_equal(left, right, *args, **kwargs)


class TestDtype(BaseDecimal, base.BaseDtypeTests):
    # Overridden as a no-op for DecimalDtype.
    def test_hashable(self, dtype):
        pass


class TestInterface(BaseDecimal, base.BaseInterfaceTests):
    pass


class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
    @pytest.mark.skip(reason="not implemented constructor from dtype")
    def test_from_dtype(self, data):
        # construct from our dtype & string dtype
        pass


class TestReshaping(BaseDecimal, base.BaseReshapingTests):
    pass


class TestGetitem(BaseDecimal, base.BaseGetitemTests):
    def test_take_na_value_other_decimal(self):
        # fill_value may be any Decimal, not only the NA value.
        arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
        result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0"))
        expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")])
        self.assert_extension_array_equal(result, expected)


class TestMissing(BaseDecimal, base.BaseMissingTests):
    pass


class Reduce:
    def check_reduce(self, s, op_name, skipna):
        """Check a reduction on ``s`` against the object-ndarray reduction.

        ``median``, ``skew`` and ``kurt`` must raise NotImplementedError;
        every other reduction must match ``np.asarray(s)``'s method result.
        """
        if op_name in ["median", "skew", "kurt"]:
            msg = r"decimal does not support the .* operation"
            with pytest.raises(NotImplementedError, match=msg):
                getattr(s, op_name)(skipna=skipna)

        else:
            result = getattr(s, op_name)(skipna=skipna)
            expected = getattr(np.asarray(s), op_name)()
            tm.assert_almost_equal(result, expected)


class TestNumericReduce(Reduce, base.BaseNumericReduceTests):
    pass


class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
    pass


class TestMethods(BaseDecimal, base.BaseMethodsTests):
    @pytest.mark.parametrize("dropna", [True, False])
    @pytest.mark.xfail(reason="value_counts not implemented yet.")
    def test_value_counts(self, all_data, dropna):
        all_data = all_data[:10]
        if dropna:
            other = np.array(all_data[~all_data.isna()])
        else:
            other = all_data

        result = pd.Series(all_data).value_counts(dropna=dropna).sort_index()
        expected = pd.Series(other).value_counts(dropna=dropna).sort_index()

        tm.assert_series_equal(result, expected)

    @pytest.mark.xfail(reason="value_counts not implemented yet.")
    def test_value_counts_with_normalize(self, data):
        return super().test_value_counts_with_normalize(data)


class TestCasting(BaseDecimal, base.BaseCastingTests):
    pass


class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
    @pytest.mark.xfail(
        reason="needs to correctly define __eq__ to handle nans, xref #27081."
    )
    def test_groupby_apply_identity(self, data_for_grouping):
        super().test_groupby_apply_identity(data_for_grouping)

    @pytest.mark.xfail(reason="GH#39098: Converts agg result to object")
    def test_groupby_agg_extension(self, data_for_grouping):
        super().test_groupby_agg_extension(data_for_grouping)


class TestSetitem(BaseDecimal, base.BaseSetitemTests):
    pass


class TestPrinting(BaseDecimal, base.BasePrintingTests):
    def test_series_repr(self, data):
        # Overriding this base test to explicitly test that
        # the custom _formatter is used
        ser = pd.Series(data)
        assert data.dtype.name in repr(ser)
        assert "Decimal: " in repr(ser)


# TODO(extension)
@pytest.mark.xfail(
    reason=(
        "raising AssertionError as this is not implemented, though easy enough to do"
    )
)
def test_series_constructor_coerce_data_to_extension_dtype_raises():
    xpr = (
        "Cannot cast data to extension dtype 'decimal'. Pass the "
        "extension array directly."
    )
    with pytest.raises(ValueError, match=xpr):
        pd.Series([0, 1, 2], dtype=DecimalDtype())


def test_series_constructor_with_dtype():
    arr = DecimalArray([decimal.Decimal("10.0")])
    result = pd.Series(arr, dtype=DecimalDtype())
    expected = pd.Series(arr)
    tm.assert_series_equal(result, expected)

    result = pd.Series(arr, dtype="int64")
    expected = pd.Series([10])
    tm.assert_series_equal(result, expected)


def test_dataframe_constructor_with_dtype():
    arr = DecimalArray([decimal.Decimal("10.0")])

    result = pd.DataFrame({"A": arr}, dtype=DecimalDtype())
    expected = pd.DataFrame({"A": arr})
    tm.assert_frame_equal(result, expected)

    arr = DecimalArray([decimal.Decimal("10.0")])
    result = pd.DataFrame({"A": arr}, dtype="int64")
    expected = pd.DataFrame({"A": [10]})
    tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("frame", [True, False])
def test_astype_dispatches(frame):
    # This is a dtype-specific test that ensures Series[decimal].astype
    # gets all the way through to ExtensionArray.astype
    # Designing a reliable smoke test that works for arbitrary data types
    # is difficult.
    data = pd.Series(DecimalArray([decimal.Decimal(2)]), name="a")
    ctx = decimal.Context()
    ctx.prec = 5

    if frame:
        data = data.to_frame()

    result = data.astype(DecimalDtype(ctx))

    if frame:
        result = result["a"]

    # The context carried by the target dtype must survive the astype.
    assert result.dtype.context.prec == ctx.prec


class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests):
    def check_opname(self, s, op_name, other, exc=None):
        # Always expect the op to succeed: the caller's ``exc`` is ignored.
        super().check_opname(s, op_name, other, exc=None)

    def test_arith_series_with_array(self, data, all_arithmetic_operators):
        op_name = all_arithmetic_operators
        s = pd.Series(data)

        # Disable the DivisionByZero / InvalidOperation traps while checking
        # ops (including against 0).  ``localcontext()`` works on a copy of
        # the current context and reinstates the original on exit, so the
        # traps are restored even if one of the checks below fails.
        with decimal.localcontext() as context:
            context.traps[decimal.DivisionByZero] = 0
            context.traps[decimal.InvalidOperation] = 0

            # Decimal supports ops with int, but not float
            other = pd.Series([int(d * 100) for d in data])
            self.check_opname(s, op_name, other)

            if "mod" not in op_name:
                self.check_opname(s, op_name, s * 2)

            self.check_opname(s, op_name, 0)
            self.check_opname(s, op_name, 5)

    def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
        # We implement divmod
        super()._check_divmod_op(s, op, other, exc=None)

    def test_error(self):
        pass


class TestComparisonOps(BaseDecimal, base.BaseComparisonOpsTests):
    def check_opname(self, s, op_name, other, exc=None):
        # Always expect the comparison to succeed: ``exc`` is ignored.
        super().check_opname(s, op_name, other, exc=None)

    def _compare_other(self, s, data, op_name, other):
        self.check_opname(s, op_name, other)

    def test_compare_scalar(self, data, all_compare_operators):
        op_name = all_compare_operators
        s = pd.Series(data)
        self._compare_other(s, data, op_name, 0.5)

    def test_compare_array(self, data, all_compare_operators):
        op_name = all_compare_operators
        s = pd.Series(data)

        alter = np.random.choice([-1, 0, 1], len(data))
        # Randomly double, halve or keep same value
        other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter]
        self._compare_other(s, data, op_name, other)


class DecimalArrayWithoutFromSequence(DecimalArray):
    """Helper class for testing error handling in _from_sequence."""

    # Must be a classmethod so that ``cls._from_sequence(...)`` reaches the
    # intended KeyError instead of failing on argument binding.
    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        raise KeyError("For the test")


class DecimalArrayWithoutCoercion(DecimalArrayWithoutFromSequence):
    """Variant whose arithmetic ops are built with ``coerce_to_dtype=False``."""

    @classmethod
    def _create_arithmetic_method(cls, op):
        return cls._create_method(op, coerce_to_dtype=False)


DecimalArrayWithoutCoercion._add_arithmetic_ops()


def test_combine_from_sequence_raises():
    # https://github.com/pandas-dev/pandas/issues/22850
    ser = pd.Series(
        DecimalArrayWithoutFromSequence(
            [decimal.Decimal("1.0"), decimal.Decimal("2.0")]
        )
    )
    result = ser.combine(ser, operator.add)

    # note: object dtype
    expected = pd.Series(
        [decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object"
    )
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "class_", [DecimalArrayWithoutFromSequence, DecimalArrayWithoutCoercion]
)
def test_scalar_ops_from_sequence_raises(class_):
    # op(EA, EA) should return an EA, or an ndarray if it's not possible
    # to return an EA with the return values.
    arr = class_([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
    result = arr + arr
    expected = np.array(
        [decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object"
    )
    tm.assert_numpy_array_equal(result, expected)


@pytest.mark.parametrize(
    "reverse, expected_div, expected_mod",
    [(False, [0, 1, 1, 2], [1, 0, 1, 0]), (True, [2, 1, 0, 0], [0, 0, 2, 2])],
)
def test_divmod_array(reverse, expected_div, expected_mod):
    # https://github.com/pandas-dev/pandas/issues/22930
    arr = to_decimal([1, 2, 3, 4])
    if reverse:
        div, mod = divmod(2, arr)
    else:
        div, mod = divmod(arr, 2)
    expected_div = to_decimal(expected_div)
    expected_mod = to_decimal(expected_mod)

    tm.assert_extension_array_equal(div, expected_div)
    tm.assert_extension_array_equal(mod, expected_mod)


def test_ufunc_fallback(data):
    a = data[:5]
    s = pd.Series(a, index=range(3, 8))
    result = np.abs(s)
    expected = pd.Series(np.abs(a), index=range(3, 8))
    tm.assert_series_equal(result, expected)


def test_array_ufunc():
    a = to_decimal([1, 2, 3])
    result = np.exp(a)
    expected = to_decimal(np.exp(a._data))
    tm.assert_extension_array_equal(result, expected)


def test_array_ufunc_series():
    a = to_decimal([1, 2, 3])
    s = pd.Series(a)
    result = np.exp(s)
    expected = pd.Series(to_decimal(np.exp(a._data)))
    tm.assert_series_equal(result, expected)


def test_array_ufunc_series_scalar_other():
    # check _HANDLED_TYPES
    a = to_decimal([1, 2, 3])
    s = pd.Series(a)
    result = np.add(s, decimal.Decimal(1))
    expected = pd.Series(np.add(a, decimal.Decimal(1)))
    tm.assert_series_equal(result, expected)


def test_array_ufunc_series_defer():
    a = to_decimal([1, 2, 3])
    s = pd.Series(a)

    expected = pd.Series(to_decimal([2, 4, 6]))
    r1 = np.add(s, a)
    r2 = np.add(a, s)

    tm.assert_series_equal(r1, expected)
    tm.assert_series_equal(r2, expected)


def test_groupby_agg():
    # Ensure that the result of agg is inferred to be decimal dtype
    # https://github.com/pandas-dev/pandas/issues/29141

    data = make_data()[:5]
    df = pd.DataFrame(
        {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
    )

    # single key, selected column
    expected = pd.Series(to_decimal([data[0], data[3]]))
    result = df.groupby("id1")["decimals"].agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)
    result = df["decimals"].groupby(df["id1"]).agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)

    # multiple keys, selected column
    expected = pd.Series(
        to_decimal([data[0], data[1], data[3]]),
        index=pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 1)]),
    )
    result = df.groupby(["id1", "id2"])["decimals"].agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)
    result = df["decimals"].groupby([df["id1"], df["id2"]]).agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)

    # multiple columns
    expected = pd.DataFrame({"id2": [0, 1], "decimals": to_decimal([data[0], data[3]])})
    result = df.groupby("id1").agg(lambda x: x.iloc[0])
    tm.assert_frame_equal(result, expected, check_names=False)


def test_groupby_agg_ea_method(monkeypatch):
    # Ensure that the result of agg is inferred to be decimal dtype
    # https://github.com/pandas-dev/pandas/issues/29141

    def DecimalArray__my_sum(self):
        return np.sum(np.array(self))

    monkeypatch.setattr(DecimalArray, "my_sum", DecimalArray__my_sum, raising=False)

    data = make_data()[:5]
    df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)})
    expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]]))

    result = df.groupby("id")["decimals"].agg(lambda x: x.values.my_sum())
    tm.assert_series_equal(result, expected, check_names=False)
    s = pd.Series(DecimalArray(data))
    result = s.groupby(np.array([0, 0, 0, 1, 1])).agg(lambda x: x.values.my_sum())
    tm.assert_series_equal(result, expected, check_names=False)


def test_indexing_no_materialize(monkeypatch):
    # See https://github.com/pandas-dev/pandas/issues/29708
    # Ensure that indexing operations do not materialize (convert to a numpy
    # array) the ExtensionArray unnecessary

    def DecimalArray__array__(self, dtype=None):
        raise Exception("tried to convert a DecimalArray to a numpy array")

    monkeypatch.setattr(DecimalArray, "__array__", DecimalArray__array__, raising=False)

    data = make_data()
    s = pd.Series(DecimalArray(data))
    df = pd.DataFrame({"a": s, "b": range(len(s))})

    # ensure the following operations do not raise an error
    s[s > 0.5]
    df[s > 0.5]
    s.at[0]
    df.at[0, "a"]


def test_to_numpy_keyword():
    # test the extra keyword
    values = [decimal.Decimal("1.1111"), decimal.Decimal("2.2222")]
    expected = np.array(
        [decimal.Decimal("1.11"), decimal.Decimal("2.22")], dtype="object"
    )
    a = pd.array(values, dtype="decimal")
    result = a.to_numpy(decimals=2)
    tm.assert_numpy_array_equal(result, expected)

    result = pd.Series(a).to_numpy(decimals=2)
    tm.assert_numpy_array_equal(result, expected)