1import itertools
2import operator
3from os.path import dirname, join
4
5import numpy as np
6import pandas as pd
7import pytest
8from pandas.core import ops
9from pandas.tests.extension import base
10from pandas.tests.extension.conftest import (  # noqa: F401
11    as_array,
12    as_frame,
13    as_series,
14    fillna_method,
15    groupby_apply_op,
16    use_numpy,
17)
18from pint.errors import DimensionalityError
19from pint.testsuite import QuantityTestCase, helpers
20
21import pint_pandas as ppi
22from pint_pandas import PintArray
23
# Unit registry taken from PintType; the fixtures and tests below use it to
# build quantities and units.
ureg = ppi.PintType.ureg
25
26
@pytest.fixture(params=[True, False])
def box_in_series(request):
    """Flag telling a test whether the data should be wrapped in a Series."""
    wrap_in_series = request.param
    return wrap_in_series
31
32
@pytest.fixture
def dtype():
    """Provide the ``pint[meter]`` extension dtype."""
    meter_dtype = ppi.PintType("pint[meter]")
    return meter_dtype
36
37
@pytest.fixture
def data():
    """Provide a PintArray holding the 100 values 1.0 .. 100.0 in nanometers."""
    magnitudes = np.arange(start=1.0, stop=101.0)
    return ppi.PintArray.from_1darray_quantity(magnitudes * ureg.nm)
43
44
@pytest.fixture
def data_missing():
    """Provide a two-element PintArray in meters: NaN first, then a valid value."""
    quantity = [np.nan, 1] * ureg.meter
    return ppi.PintArray.from_1darray_quantity(quantity)
48
49
@pytest.fixture
def data_for_twos():
    """Provide a PintArray of one hundred 2.0 values in meters."""
    twos = [2.0 for _ in range(100)]
    return ppi.PintArray.from_1darray_quantity(twos * ureg.meter)
56
57
@pytest.fixture(params=["data", "data_missing"])
def all_data(request, data, data_missing):
    """Yield the ``data`` fixture on one run and ``data_missing`` on the other."""
    candidates = {"data": data, "data_missing": data_missing}
    # .get() keeps the implicit ``None`` result for an unrecognised param.
    return candidates.get(request.param)
64
65
@pytest.fixture
def data_repeated(data):
    """Provide a generator factory that yields ``data`` itself ``count`` times."""
    # Modelled on the equivalent fixture in pandas' own integer-array tests:
    # https://github.com/pandas-dev/pandas/blob/master/pandas/tests/extension/integer/test_integer.py
    def gen(count):
        yield from itertools.repeat(data, count)

    yield gen
75
76
@pytest.fixture(params=[None, lambda x: x])
def sort_by_key(request):
    """
    Supply the ``key`` argument used by the sorting tests.

    Parametrized over ``None`` (no key) and an identity function.
    """
    key = request.param
    return key
84
85
@pytest.fixture
def data_for_sorting():
    """Provide three centimeter values ordered [B, C, A], where A < B < C."""
    # TODO: get more sophisticated and mix units, e.g.
    # [1 * ureg.meter, 3 * ureg.meter, 10 * ureg.centimeter]
    unsorted = [0.3, 10, -50] * ureg.centimeter
    return ppi.PintArray.from_1darray_quantity(unsorted)
91
92
@pytest.fixture
def data_missing_for_sorting():
    """Provide three centimeter values ordered [B, NaN, A], where A < B."""
    # TODO: get more sophisticated and mix units, e.g.
    # [4 * ureg.meter, np.nan, 10 * ureg.centimeter]
    unsorted = [4, np.nan, -5] * ureg.centimeter
    return ppi.PintArray.from_1darray_quantity(unsorted)
98
99
@pytest.fixture
def na_cmp():
    """Provide a binary predicate that is true only when both operands are NaN."""

    def both_nan(x, y):
        # Evaluate both sides up front so neither magnitude lookup is skipped.
        x_is_nan = bool(np.isnan(x.magnitude))
        y_is_nan = bool(np.isnan(y.magnitude))
        return x_is_nan and y_is_nan

    return both_nan
104
105
@pytest.fixture
def na_value():
    """Provide the NA scalar reported by the ``meter`` PintType."""
    meter_type = ppi.PintType("meter")
    return meter_type.na_value
109
110
@pytest.fixture
def data_for_grouping():
    """Provide eight meter values laid out as [B, B, NaN, NaN, A, A, B, C]."""
    # TODO: get more sophisticated here and use units on all these quantities
    low = 1.0
    mid = 2.0 ** 32 + 1
    high = 2.0 ** 32 + 10
    layout = [mid, mid, np.nan, np.nan, low, low, mid, high]
    return ppi.PintArray.from_1darray_quantity(layout * ureg.m)
121
122
# === Fixtures the base tests need but the pandas extension docs don't list ===
# Copied from pandas/pandas/conftest.py.
125_all_arithmetic_operators = [
126    "__add__",
127    "__radd__",
128    "__sub__",
129    "__rsub__",
130    "__mul__",
131    "__rmul__",
132    "__floordiv__",
133    "__rfloordiv__",
134    "__truediv__",
135    "__rtruediv__",
136    "__pow__",
137    "__rpow__",
138    "__mod__",
139    "__rmod__",
140]
141
142
@pytest.fixture(params=_all_arithmetic_operators)
def all_arithmetic_operators(request):
    """
    Supply the dunder name of one arithmetic operation per test run.
    """
    dunder_name = request.param
    return dunder_name
149
150
@pytest.fixture(params=["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"])
def all_compare_operators(request):
    """
    Supply the dunder name of one comparison operation per test run.

    Covers ``==``, ``!=``, ``<=``, ``<``, ``>=`` and ``>``.
    """
    dunder_name = request.param
    return dunder_name
164
165
166# commented functions aren't implemented
167_all_numeric_reductions = [
168    "sum",
169    "max",
170    "min",
171    "mean",
172    # "prod",
173    # "std",
174    # "var",
175    "median",
176    # "kurt",
177    # "skew",
178]
179
180
@pytest.fixture(params=_all_numeric_reductions)
def all_numeric_reductions(request):
    """
    Supply the name of one numeric reduction per test run.
    """
    reduction_name = request.param
    return reduction_name
187
188
# Boolean reductions exercised by the tests.
_all_boolean_reductions = [
    "all",
    "any",
]


@pytest.fixture(params=_all_boolean_reductions)
def all_boolean_reductions(request):
    """
    Supply the name of one boolean reduction per test run.
    """
    reduction_name = request.param
    return reduction_name
198
199
200# =================================================================
201
202
class TestCasting(base.BaseCastingTests):
    """Run the inherited casting tests without any overrides."""
205
206
class TestConstructors(base.BaseConstructorsTests):
    """Constructor tests; the overrides below are marked as expected failures."""

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_series_constructor_no_data_with_index(self, dtype, na_value):
        """Build a Series from an index alone, then from an empty index."""
        built = pd.Series(index=[1, 2, 3], dtype=dtype)
        wanted = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype)
        self.assert_series_equal(built, wanted)

        # GH 33559 - empty index
        built = pd.Series(index=[], dtype=dtype)
        empty_object_index = pd.Index([], dtype="object")
        wanted = pd.Series([], index=empty_object_index, dtype=dtype)
        self.assert_series_equal(built, wanted)

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_series_constructor_scalar_na_with_index(self, dtype, na_value):
        """Broadcast the NA scalar over a three-label index."""
        built = pd.Series(na_value, index=[1, 2, 3], dtype=dtype)
        wanted = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype)
        self.assert_series_equal(built, wanted)

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_series_constructor_scalar_with_index(self, data, dtype):
        """Broadcast a real scalar over a three-label and a one-label index."""
        scalar = data[0]

        built = pd.Series(scalar, index=[1, 2, 3], dtype=dtype)
        wanted = pd.Series([scalar] * 3, index=[1, 2, 3], dtype=dtype)
        self.assert_series_equal(built, wanted)

        built = pd.Series(scalar, index=["foo"], dtype=dtype)
        wanted = pd.Series([scalar], index=["foo"], dtype=dtype)
        self.assert_series_equal(built, wanted)
235
236
class TestDtype(base.BaseDtypeTests):
    """Run the inherited dtype tests without any overrides."""
239
240
class TestGetitem(base.BaseGetitemTests):
    """Getitem tests with a local version of the wrong-length-mask check."""

    def test_getitem_mask_raises(self, data):
        """A two-element boolean mask must be rejected with IndexError."""
        numpy_mask = np.array([True, False])
        msg = f"Boolean index has wrong length: 2 instead of {len(data)}"

        # Same check for a plain numpy mask and a pandas "boolean" array.
        for mask in (numpy_mask, pd.array(numpy_mask, dtype="boolean")):
            with pytest.raises(IndexError, match=msg):
                data[mask]
251
252
class TestGroupby(base.BaseGroupbyTests):
    """Groupby tests; the identity-apply override is an expected failure."""

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_groupby_apply_identity(self, data_for_grouping):
        """Applying ``lambda x: x.array`` per group returns each group's array."""
        df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})
        result = df.groupby("A").B.apply(lambda x: x.array)

        # Row positions belonging to group keys 1, 2, 3 and 4 respectively.
        rows_per_group = ([0, 1, 6], [2, 3], [4, 5], [7])
        group_arrays = [df.B.iloc[rows].array for rows in rows_per_group]
        expected = pd.Series(
            group_arrays,
            index=pd.Index([1, 2, 3, 4], name="A"),
            name="B",
        )
        self.assert_series_equal(result, expected)
269
270
class TestInterface(base.BaseInterfaceTests):
    """Run the inherited interface tests without any overrides."""
273
274
class TestMethods(base.BaseMethodsTests):
    """Method tests with local overrides of selected base tests."""

    @pytest.mark.filterwarnings("ignore::pint.UnitStrippedWarning")
    # See test_setitem_mask_broadcast note
    @pytest.mark.parametrize("dropna", [True, False])
    def test_value_counts(self, all_data, dropna):
        """value_counts on the first ten values matches the pre-filtered data."""
        head = all_data[:10]
        kept = head[~head.isna()] if dropna else head

        counts = pd.Series(head).value_counts(dropna=dropna).sort_index()
        reference = pd.Series(kept).value_counts(dropna=dropna).sort_index()

        self.assert_series_equal(counts, reference)

    @pytest.mark.filterwarnings("ignore::pint.UnitStrippedWarning")
    # See test_setitem_mask_broadcast note
    @pytest.mark.parametrize("box", [pd.Series, lambda x: x])
    @pytest.mark.parametrize("method", [lambda x: x.unique(), pd.unique])
    def test_unique(self, data, box, method):
        """Two copies of one value collapse to a single-element array."""
        duplicated = box(data._from_sequence([data[0], data[0]]))

        uniques = method(duplicated)

        assert len(uniques) == 1
        assert isinstance(uniques, type(data))
        assert uniques[0] == duplicated[0]

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_fillna_copy_frame(self, data_missing):
        """DataFrame.fillna must hand back different underlying values."""
        frame = pd.DataFrame({"A": data_missing.take([1, 1])})

        fill = frame.iloc[0, 0]
        filled = frame.fillna(fill)

        assert frame.A.values is not filled.A.values

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_fillna_copy_series(self, data_missing):
        """Series.fillna must hand back different values and keep the source."""
        arr = data_missing.take([1, 1])
        ser = pd.Series(arr)

        fill = ser[0]
        filled = ser.fillna(fill)

        assert ser._values is not filled._values
        assert ser._values is arr

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_searchsorted(self, data_for_sorting, as_series):  # noqa: F811
        """searchsorted finds left/right insertion points for each scalar."""
        b, c, a = data_for_sorting
        arr = type(data_for_sorting)._from_sequence([a, b, c])
        if as_series:
            arr = pd.Series(arr)

        # a, b, c sit at positions 0, 1, 2 of the sorted array.
        for position, scalar in enumerate((a, b, c)):
            assert arr.searchsorted(scalar) == position
            assert arr.searchsorted(scalar, side="right") == position + 1

        found = arr.searchsorted(arr.take([0, 2]))
        wanted = np.array([0, 2], dtype=np.intp)
        self.assert_numpy_array_equal(found, wanted)

        # sorter
        sorter = np.array([1, 2, 0])
        assert data_for_sorting.searchsorted(a, sorter=sorter) == 0

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_where_series(self, data, na_value, as_frame):  # noqa: F811
        """where() masks with NA by default and with ``other`` when given."""
        assert data[0] != data[1]
        cls = type(data)
        a, b = data[:2]

        def build(values):
            return cls._from_sequence(values, dtype=data.dtype)

        def boxed(series):
            # Frame variants use a single column named "a".
            return series.to_frame(name="a") if as_frame else series

        ser = boxed(pd.Series(build([a, a, b, b])))
        cond = np.array([True, True, False, False])
        if as_frame:
            cond = cond.reshape(-1, 1)

        result = ser.where(cond)
        expected = boxed(pd.Series(build([a, a, na_value, na_value])))
        self.assert_equal(result, expected)

        # array other
        cond = np.array([True, False, True, True])
        other = build([a, b, a, b])
        if as_frame:
            other = pd.DataFrame({"a": other})
            cond = pd.DataFrame({"a": cond})
        result = ser.where(cond, other)
        expected = boxed(pd.Series(build([a, b, b, b])))
        self.assert_equal(result, expected)

    @pytest.mark.parametrize("ascending", [True, False])
    def test_sort_values(self, data_for_sorting, ascending, sort_by_key):
        """sort_values orders the [B, C, A] fixture in either direction."""
        ser = pd.Series(data_for_sorting)
        result = ser.sort_values(ascending=ascending, key=sort_by_key)

        ascending_order = ser.iloc[[2, 0, 1]]
        expected = ascending_order if ascending else ascending_order[::-1]

        self.assert_series_equal(result, expected)

    @pytest.mark.parametrize("ascending", [True, False])
    def test_sort_values_missing(
        self, data_missing_for_sorting, ascending, sort_by_key
    ):
        """The missing value stays last whichever direction is requested."""
        ser = pd.Series(data_missing_for_sorting)
        result = ser.sort_values(ascending=ascending, key=sort_by_key)

        positions = [2, 0, 1] if ascending else [0, 2, 1]
        expected = ser.iloc[positions]
        self.assert_series_equal(result, expected)
405
406
class TestArithmeticOps(base.BaseArithmeticOpsTests):
    """Arithmetic-operator tests with local checking helpers and overrides."""

    def check_opname(self, s, op_name, other, exc=None):
        """Resolve ``op_name`` to a callable and delegate to ``_check_op``."""
        op = self.get_op_from_name(op_name)

        self._check_op(s, op, other, exc)

    def _check_op(self, s, op, other, exc=None):
        """Compare ``op(s, other)`` with ``s.combine``, or expect ``exc``.

        When ``exc`` is None the operator result must equal the result of
        ``s.combine(other, op)``; otherwise the call must raise ``exc``.
        """
        if exc is None:
            result = op(s, other)
            expected = s.combine(other, op)
            self.assert_series_equal(result, expected)
        else:
            with pytest.raises(exc):
                op(s, other)

    def _check_divmod_op(self, s, op, other, exc=None):
        """Check both halves of a divmod-style ``op`` or expect ``exc``.

        ``op`` is either the builtin ``divmod`` or a reflected variant; for the
        reflected variant the operands are swapped when computing the
        expectation.
        """
        # divmod has multiple return values, so check separately
        if exc is None:
            result_div, result_mod = op(s, other)
            if op is divmod:
                expected_div, expected_mod = s // other, s % other
            else:
                expected_div, expected_mod = other // s, other % s
            self.assert_series_equal(result_div, expected_div)
            self.assert_series_equal(result_mod, expected_mod)
        else:
            with pytest.raises(exc):
                # Call the operator under test (not always the builtin divmod)
                # so the reflected case is exercised too.
                op(s, other)

    def _get_exception(self, data, op_name):
        """Return ``(op_name, expected_exception_or_None)`` for an operator.

        Power operators are expected to raise DimensionalityError; every other
        operator is expected to succeed.
        """
        if op_name in ["__pow__", "__rpow__"]:
            return op_name, DimensionalityError
        else:
            return op_name, None

    def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
        """Apply each operator between a Series and its first element."""
        # series & scalar
        op_name, exc = self._get_exception(data, all_arithmetic_operators)
        s = pd.Series(data)
        self.check_opname(s, op_name, s.iloc[0], exc=exc)

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
        """Apply each operator between a one-column frame and a scalar."""
        # frame & scalar
        op_name, exc = self._get_exception(data, all_arithmetic_operators)
        df = pd.DataFrame({"A": data})
        self.check_opname(df, op_name, data[0], exc=exc)

    @pytest.mark.xfail(run=True, reason="s.combine does not accept arrays")
    def test_arith_series_with_array(self, data, all_arithmetic_operators):
        """Apply each operator between a Series and the raw array."""
        # ndarray & other series
        op_name, exc = self._get_exception(data, all_arithmetic_operators)
        s = pd.Series(data)
        self.check_opname(s, op_name, data, exc=exc)

    # parameterise this to try divisor not equal to 1
    def test_divmod(self, data):
        """Check divmod and reflected divmod against a 1 Mm quantity."""
        s = pd.Series(data)
        self._check_divmod_op(s, divmod, 1 * ureg.Mm)
        self._check_divmod_op(1 * ureg.Mm, ops.rdivmod, s)

    @pytest.mark.xfail(run=True, reason="Test is deleted in pd 1.3, pd GH #39386")
    def test_error(self, data, all_arithmetic_operators):
        """Operators must raise for invalid scalars, arrays and 2-D inputs."""
        # invalid ops

        op = all_arithmetic_operators
        s = pd.Series(data)
        # Bound operator methods; named so they don't shadow the imported
        # ``ops`` module.
        series_op = getattr(s, op)
        array_op = getattr(data, op)

        # invalid scalars
        # TODO: work out how to make this more specific/test for the two
        #       different possible errors here
        with pytest.raises(Exception):
            series_op("foo")

        # TODO: work out how to make this more specific/test for the two
        #       different possible errors here
        with pytest.raises(Exception):
            series_op(pd.Timestamp("20180101"))

        # invalid array-likes
        # TODO: work out how to make this more specific/test for the two
        #       different possible errors here
        #
        # This won't always raise exception, eg for foo % 3 m
        if "mod" not in op:
            with pytest.raises(Exception):
                series_op(pd.Series("foo", index=s.index))

        # 2d
        with pytest.raises(KeyError):
            array_op(pd.DataFrame({"A": s}))

        with pytest.raises(ValueError):
            array_op(np.arange(len(s)).reshape(-1, len(s)))

    @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
    def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
        """``data.__add__`` must return NotImplemented for Series/DataFrame."""
        # EAs should return NotImplemented for ops with Series/DataFrame
        # Pandas takes care of unboxing the series and calling the EA's op.
        other = pd.Series(data)
        if box is pd.DataFrame:
            other = other.to_frame()
        if hasattr(data, "__add__"):
            result = data.__add__(other)
            assert result is NotImplemented
        else:
            # pytest.skip raises by itself; no explicit ``raise`` is needed.
            pytest.skip(f"{type(data).__name__} does not implement add")
516
517
class TestComparisonOps(base.BaseComparisonOpsTests):
    """Comparison-operator tests checked against the underlying quantity."""

    def _compare_other(self, s, data, op_name, other):
        """Compare ``op(s, other)`` with the same op on ``s.values.quantity``."""
        compare = self.get_op_from_name(op_name)

        from_series = compare(s, other)
        from_quantity = compare(s.values.quantity, other)
        assert (from_series == from_quantity).all()

    def test_compare_scalar(self, data, all_compare_operators):
        """Compare a Series with the first element of the data."""
        self._compare_other(pd.Series(data), data, all_compare_operators, data[0])

    def test_compare_array(self, data, all_compare_operators):
        """Compare a Series with the whole array."""
        # nb this compares an quantity containing array
        # eg Q_([1,2],"m")
        self._compare_other(pd.Series(data), data, all_compare_operators, data)

    @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
    def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
        """``__eq__`` and ``__ne__`` return NotImplemented for Series/DataFrame."""
        # EAs should return NotImplemented for ops with Series/DataFrame
        # Pandas takes care of unboxing the series and calling the EA's op.
        other = pd.Series(data)
        if box is pd.DataFrame:
            other = other.to_frame()

        for dunder in ("__eq__", "__ne__"):
            if not hasattr(data, dunder):
                raise pytest.skip(
                    f"{type(data).__name__} does not implement {dunder}"
                )
            outcome = getattr(data, dunder)(other)
            assert outcome is NotImplemented
559
560
class TestOpsUtil(base.BaseOpsUtil):
    """Pick up the inherited operator-utility helpers without any overrides."""
563
564
class TestParsing(base.BaseParsingTests):
    """Run the inherited parsing tests without any overrides."""
567
568
class TestPrinting(base.BasePrintingTests):
    """Run the inherited printing tests without any overrides."""
571
572
class TestMissing(base.BaseMissingTests):
    """Missing-data tests; the fillna overrides are expected failures."""

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_fillna_scalar(self, data_missing):
        """Filling the array with a valid scalar gives a repeatable result."""
        valid = data_missing[1]
        result = data_missing.fillna(valid)
        expected = data_missing.fillna(valid)
        self.assert_extension_array_equal(result, expected)

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_fillna_series(self, data_missing):
        """Fill a Series with a scalar, with a filled Series, and with itself."""
        fill_value = data_missing[1]
        ser = pd.Series(data_missing)

        result = ser.fillna(fill_value)
        expected = pd.Series(
            data_missing._from_sequence(
                [fill_value, fill_value], dtype=data_missing.dtype
            )
        )
        self.assert_series_equal(result, expected)

        # Fill with a series
        result = ser.fillna(expected)
        self.assert_series_equal(result, expected)

        # Fill with a series not affecting the missing values
        result = ser.fillna(ser)
        self.assert_series_equal(result, ser)

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_fillna_frame(self, data_missing):
        """Filling a two-column frame replaces the missing extension value."""
        fill_value = data_missing[1]

        result = pd.DataFrame({"A": data_missing, "B": [1, 2]}).fillna(fill_value)

        expected = pd.DataFrame(
            {
                "A": data_missing._from_sequence(
                    [fill_value, fill_value], dtype=data_missing.dtype
                ),
                "B": [1, 2],
            }
        )
        # Both operands are DataFrames, so the frame comparison is the right
        # assertion here (the Series one cannot validate frames).
        self.assert_frame_equal(result, expected)
617
618
class TestNumericReduce(base.BaseNumericReduceTests):
    """Numeric reductions must equal the magnitude reduction with units."""

    def check_reduce(self, s, op_name, skipna):
        """Reduce ``s`` and compare with the same reduction on its magnitudes."""
        quantity = s.values.quantity
        result = getattr(s, op_name)(skipna=skipna)

        # Reduce the bare magnitudes, then re-attach the original units.
        magnitudes = pd.Series(quantity._magnitude)
        reduced_magnitude = getattr(magnitudes, op_name)(skipna=skipna)
        expected = ureg.Quantity(reduced_magnitude, quantity.units)

        assert result == expected
628
629
class TestBooleanReduce(base.BaseBooleanReduceTests):
    """Boolean reductions must equal the reduction of the bare magnitudes."""

    def check_reduce(self, s, op_name, skipna):
        """Reduce ``s`` and compare with the same reduction on its magnitudes."""
        result = getattr(s, op_name)(skipna=skipna)

        magnitudes = pd.Series(s.values.quantity._magnitude)
        expected = getattr(magnitudes, op_name)(skipna=skipna)

        assert result == expected
637
638
class TestReshaping(base.BaseReshapingTests):
    """Reshaping tests; the unstack override is an expected failure."""

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    # Overriding the base test drops its parametrization, so the ``index``
    # argument has to be parametrized again here.
    @pytest.mark.parametrize(
        "index",
        [
            # Two levels, uniform.
            pd.MultiIndex.from_product(([["A", "B"], ["a", "b"]]), names=["a", "b"]),
            # non-uniform
            pd.MultiIndex.from_tuples([("A", "a"), ("A", "b"), ("B", "b")]),
            # three levels, non-uniform
            pd.MultiIndex.from_product([("A", "B"), ("a", "b", "c"), (0, 1, 2)]),
            pd.MultiIndex.from_tuples(
                [
                    ("A", "a", 1),
                    ("A", "b", 0),
                    ("A", "a", 0),
                    ("B", "a", 0),
                    ("B", "c", 1),
                ]
            ),
        ],
    )
    @pytest.mark.parametrize("obj", ["series", "frame"])
    def test_unstack(self, data, index, obj):
        """Unstacking keeps the extension type and matches the object result.

        ``data`` is truncated to the index length and unstacked over every
        permutation of fewer than ``index.nlevels`` levels.
        """
        data = data[: len(index)]
        if obj == "series":
            ser = pd.Series(data, index=index)
        else:
            ser = pd.DataFrame({"A": data, "B": data}, index=index)

        n = index.nlevels
        levels = list(range(n))
        # e.g. for three levels, [0, 1, 2], the permutations tried are
        # [(0,), (1,), (2,), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
        combinations = itertools.chain.from_iterable(
            itertools.permutations(levels, i) for i in range(1, n)
        )

        for level in combinations:
            result = ser.unstack(level=level)
            assert all(
                isinstance(result[col].array, type(data)) for col in result.columns
            )

            if obj == "series":
                # We should get the same result with to_frame+unstack+droplevel
                df = ser.to_frame()

                alt = df.unstack(level=level).droplevel(0, axis=1)
                self.assert_frame_equal(result, alt)

            expected = ser.astype(object).unstack(level=level)
            result = result.astype(object)

            self.assert_frame_equal(result, expected)
674
675
class TestSetitem(base.BaseSetitemTests):
    """Setitem tests; every override below is an expected failure."""

    @pytest.mark.parametrize("setter", ["loc", None])
    @pytest.mark.filterwarnings("ignore::pint.UnitStrippedWarning")
    # Pandas performs a hasattr(__array__), which triggers the warning
    # Debugging it does not pass through a PintArray, so
    # I think this needs changing in pint quantity
    # eg s[[True]*len(s)]=Q_(1,"m")
    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_setitem_mask_broadcast(self, data, setter):
        """Assigning a scalar through a boolean mask broadcasts it."""
        ser = pd.Series(data)
        mask = np.zeros(len(data), dtype=bool)
        mask[:2] = True

        # "loc" goes through the indexer; None means plain __setitem__.
        target = getattr(ser, setter) if setter else ser

        operator.setitem(target, mask, data[10])
        assert ser[0] == data[10]
        assert ser[1] == data[10]

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_setitem_sequence_broadcasts(self, data, box_in_series):
        """Assigning a scalar to a list of positions broadcasts it."""
        target = pd.Series(data) if box_in_series else data
        target[[0, 1]] = target[2]
        assert target[0] == target[2]
        assert target[1] == target[2]

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    @pytest.mark.parametrize(
        "idx",
        [[0, 1, 2], pd.array([0, 1, 2], dtype="Int64"), np.array([0, 1, 2])],
        ids=["list", "integer-array", "numpy-array"],
    )
    def test_setitem_integer_array(self, data, idx, box_in_series):
        """Assigning through integer indexers overwrites the first three items."""
        arr = data[:5].copy()
        expected = data.take([0, 0, 0, 3, 4])

        if box_in_series:
            arr, expected = pd.Series(arr), pd.Series(expected)

        arr[idx] = arr[0]
        self.assert_equal(arr, expected)

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_setitem_slice(self, data, box_in_series):
        """Assigning a scalar to a slice overwrites the first three items."""
        arr = data[:5].copy()
        expected = data.take([0, 0, 0, 3, 4])

        if box_in_series:
            arr, expected = pd.Series(arr), pd.Series(expected)

        arr[:3] = data[0]
        self.assert_equal(arr, expected)

    @pytest.mark.xfail(run=True, reason="__iter__ / __len__ issue")
    def test_setitem_loc_iloc_slice(self, data):
        """Slice assignment through iloc and loc gives the same result."""
        source = pd.Series(data[:5].copy(), index=["a", "b", "c", "d", "e"])
        expected = pd.Series(data.take([0, 0, 0, 3, 4]), index=source.index)

        by_position = source.copy()
        by_position.iloc[:3] = data[0]
        self.assert_equal(by_position, expected)

        by_label = source.copy()
        by_label.loc[:"c"] = data[0]
        self.assert_equal(by_label, expected)
747
748
class TestOffsetUnits(object):
    # Checks pd.concat on Series backed by PintArrays in an offset unit (degC).

    @pytest.mark.xfail(run=True, reason="TODO untested issue that was fixed")
    def test_offset_concat(self):
        """Concatenate two degC Series of lengths 5 and 6 along axis=1."""
        q_a = ureg.Quantity(np.arange(5), ureg.Unit("degC"))
        q_b = ureg.Quantity(np.arange(6), ureg.Unit("degC"))

        a = pd.Series(PintArray(q_a))
        b = pd.Series(PintArray(q_b))

        result = pd.concat([a, b], axis=1)
        # NOTE(review): `expected` joins q_b with itself and never uses q_a,
        # and it is a Series while an axis=1 concat presumably yields a
        # DataFrame -- confirm the intended expectation.
        expected = pd.Series(PintArray(np.concatenate([q_b, q_b]), dtype="pint[degC]"))
        # NOTE(review): this class derives from `object` and defines no
        # `assert_equal` here, so this call looks like it raises
        # AttributeError (masked by the xfail) -- confirm.
        self.assert_equal(result, expected)
761
762
763# would be ideal to just test all of this by running the example notebook
764# but this isn't a discussion we've had yet
765
766
class TestUserInterface:
    """End-user workflows: value access, arithmetic, construction, notebook."""

    def test_get_underlying_data(self, data):
        """Series.values compares equal to the array it was built from."""
        values = pd.Series(data).values
        # this first test creates an array of bool (which is desired, eg for indexing)
        assert all(values == data)
        assert values[23] == data[23]

    def test_arithmetic(self, data):
        """Adding a Series to itself doubles the underlying data."""
        ser = pd.Series(data)
        doubled = ser + ser
        assert all(doubled.values == 2 * data)

    def test_initialisation(self, data):
        """Every supported way of building a meter column gives equal columns."""
        # fails with plain array
        # works with PintArray
        columns = {
            "length": pd.Series([2.0, 3.0], dtype="pint[m]"),
            "width": PintArray([2.0, 3.0], dtype="pint[m]"),
            "distance": PintArray([2.0, 3.0], dtype="m"),
            "height": PintArray([2.0, 3.0], dtype=ureg.m),
            "depth": PintArray.from_1darray_quantity(
                ureg.Quantity([2.0, 3.0], ureg.m)
            ),
        }
        df = pd.DataFrame(columns)

        for name in df.columns:
            assert all(df[name] == df.length)

    def test_df_operations(self):
        """Smoke-test the operations shown in the example notebook."""
        # simply a copy of what's in the notebook
        measurements = [1.0, 2.0, 2.0, 3.0]
        df = pd.DataFrame(
            {
                "torque": pd.Series(measurements, dtype="pint[lbf ft]"),
                "angular_velocity": pd.Series(measurements, dtype="pint[rpm]"),
            }
        )

        df["power"] = df["torque"] * df["angular_velocity"]

        # The bare expressions below only have to evaluate without raising.
        df.power.values
        df.power.values.quantity
        df.angular_velocity.values.data

        df.power.pint.units

        df.power.pint.to("kW").values

        csv_path = join(dirname(__file__), "pandas_test.csv")

        raw = pd.read_csv(csv_path, header=[0, 1])
        quantified = raw.pint.quantify(level=-1)

        quantified["mech power"] = quantified.speed * quantified.torque
        quantified["fluid power"] = (
            quantified["fuel flow rate"] * quantified["rail pressure"]
        )

        quantified.pint.dequantify()

        for column in ("fluid power", "mech power"):
            quantified[column] = quantified[column].pint.to("kW")
        quantified.pint.dequantify()

        quantified.pint.to_base_units().pint.dequantify()
831
832
class TestDataFrameAccessor:
    """Tests for the DataFrame ``.pint`` accessor."""

    def test_index_maintained(self):
        """quantify followed by dequantify keeps both MultiIndexes intact."""
        csv_path = join(dirname(__file__), "pandas_test.csv")
        df = pd.read_csv(csv_path, header=[0, 1])

        car_types = ["Holden", "Holden", "Holden", "Ford", "Ford", "Ford"]
        metrics = [
            "speed",
            "mech power",
            "torque",
            "rail pressure",
            "fuel flow rate",
            "fluid power",
        ]
        units = ["rpm", "kW", "N m", "bar", "l/min", "kW"]
        df.columns = pd.MultiIndex.from_arrays(
            [car_types, metrics, units],
            names=["Car type", "metric", "unit"],
        )
        df.index = pd.MultiIndex.from_arrays(
            [
                [1, 12, 32, 48],
                ["Tim", "Tim", "Jane", "Steve"],
            ],  # noqa E231
            names=["Measurement number", "Measurer"],
        )

        expected = df.copy()

        # we expect the result to come back with pint names, not input
        # names
        def get_pint_value(in_str):
            return str(ureg.Quantity(1, in_str).units)

        # Position of the "unit" level among the column levels.
        unit_positions = [
            position
            for position, level_name in enumerate(df.columns.names)
            if level_name == "unit"
        ]
        units_level = unit_positions[0]

        pint_unit_names = df.columns.levels[units_level].map(get_pint_value)
        expected.columns = df.columns.set_levels(pint_unit_names, level="unit")

        result = df.pint.quantify(level=-1).pint.dequantify()

        pd.testing.assert_frame_equal(result, expected)
879
880
class TestSeriesAccessors:
    """Check that ``Series.pint`` mirrors the wrapped quantity's API.

    Each test calls an attribute or method through ``s.pint`` and through
    ``data.quantity`` and compares the two results.
    """

    @pytest.mark.parametrize(
        "attr",
        [
            "debug_used",
            "default_format",
            "dimensionality",
            "dimensionless",
            "force_ndarray",
            "shape",
            "u",
            "unitless",
            "units",
        ],
    )
    def test_series_scalar_property_accessors(self, data, attr):
        """Properties with a scalar value must match exactly."""
        s = pd.Series(data)
        assert getattr(s.pint, attr) == getattr(data.quantity, attr)

    @pytest.mark.parametrize(
        "attr",
        [
            "m",
            "magnitude",
            # 'imag', # failing, not sure why
            # 'real', # failing, not sure why
        ],
    )
    def test_series_property_accessors(self, data, attr):
        """Element-wise properties must match element by element."""
        s = pd.Series(data)
        assert all(getattr(s.pint, attr) == pd.Series(getattr(data.quantity, attr)))

    @pytest.mark.parametrize(
        "attr_args",
        [
            # Each entry is (method name, tuple of positional arguments).
            ("check", ("[length]",)),
            ("compatible_units", ()),
            # ('format_babel', ()), Needs babel installed?
            # ('plus_minus', ()), Needs uncertanties
            # ('to_tuple', ()),
            ("tolist", ()),
        ],
    )
    def test_series_scalar_method_accessors(self, data, attr_args):
        """Methods with a scalar result must match exactly."""
        attr, args = attr_args
        s = pd.Series(data)
        assert getattr(s.pint, attr)(*args) == getattr(data.quantity, attr)(*args)

    @pytest.mark.parametrize(
        "attr_args",
        [
            ("ito", ("mi",)),
            ("ito_base_units", ()),
            ("ito_reduced_units", ()),
            ("ito_root_units", ()),
            ("put", (1, 1 * ureg.nm)),
        ],
    )
    def test_series_inplace_method_accessors(self, data, attr_args):
        """In-place methods must change the Series like they change the data."""
        attr, args = attr_args
        from copy import deepcopy

        # Work on a copy so the accessor and the raw quantity are mutated
        # independently before being compared.
        s = pd.Series(deepcopy(data))
        getattr(s.pint, attr)(*args)
        getattr(data.quantity, attr)(*args)
        assert all(s.values == data)

    @pytest.mark.parametrize(
        "attr_args",
        [
            ("clip", (10 * ureg.nm, 20 * ureg.nm)),
            (
                "from_tuple",
                (PintArray(np.arange(1, 101), dtype=ureg.m).quantity.to_tuple(),),
            ),
            ("m_as", ("mi",)),
            ("searchsorted", (10 * ureg.nm,)),
            # One-element tuple: a bare ("m") is just the string "m".
            ("to", ("m",)),
            ("to_base_units", ()),
            ("to_compact", ()),
            ("to_reduced_units", ()),
            ("to_root_units", ()),
            # ('to_timedelta', ()),
        ],
    )
    def test_series_method_accessors(self, data, attr_args):
        """Element-wise methods must match element by element."""
        attr, args = attr_args
        s = pd.Series(data)
        assert all(getattr(s.pint, attr)(*args) == getattr(data.quantity, attr)(*args))
973
974
# Binary arithmetic operators exercised by the PintArray/Quantity parity
# tests, looked up on the ``operator`` module by name.
arithmetic_ops = [
    getattr(operator, op_name)
    for op_name in ("add", "sub", "mul", "truediv", "floordiv", "pow")
]
983
# Binary comparison operators exercised by the PintArray/Quantity parity
# tests, looked up on the ``operator`` module by name.
comparative_ops = [
    getattr(operator, op_name) for op_name in ("eq", "le", "lt", "ge", "gt")
]
991
992
993class TestPintArrayQuantity(QuantityTestCase):
994    FORCE_NDARRAY = True
995
996    def test_pintarray_creation(self):
997        x = ureg.Quantity([1.0, 2.0, 3.0], "m")
998        ys = [
999            PintArray.from_1darray_quantity(x),
1000            PintArray._from_sequence([item for item in x]),
1001        ]
1002        for y in ys:
1003            helpers.assert_quantity_almost_equal(x, y.quantity)
1004
1005    @pytest.mark.filterwarnings("ignore::pint.UnitStrippedWarning")
1006    @pytest.mark.filterwarnings("ignore::RuntimeWarning")
1007    def test_pintarray_operations(self):
1008        # Perform operations with Quantities and PintArrays
1009        # The resulting Quantity and PintArray.Data should be the same
1010        # a op b == c
1011        # warnings ignored here as it these tests are to ensure
1012        # pint array behaviour is the same as quantity
1013        def test_op(a_pint, a_pint_array, b_, coerce=True):
1014            try:
1015                result_pint = op(a_pint, b_)
1016                if coerce:
1017                    # a PintArray is returned from arithmetics, so need the data
1018                    c_pint_array = op(a_pint_array, b_).quantity
1019                else:
1020                    # a boolean array is returned from comparatives
1021                    c_pint_array = op(a_pint_array, b_)
1022
1023                helpers.assert_quantity_almost_equal(result_pint, c_pint_array)
1024
1025            except Exception as caught_exception:
1026                with pytest.raises(type(caught_exception)):
1027                    op(a_pint_array, b)
1028
1029        a_pints = [
1030            ureg.Quantity([3.0, 4.0], "m"),
1031            ureg.Quantity([3.0, 4.0], ""),
1032        ]
1033
1034        a_pint_arrays = [PintArray.from_1darray_quantity(q) for q in a_pints]
1035
1036        bs = [
1037            2,
1038            ureg.Quantity(3, "m"),
1039            [1.0, 3.0],
1040            [3.3, 4.4],
1041            ureg.Quantity([6.0, 6.0], "m"),
1042            ureg.Quantity([7.0, np.nan]),
1043        ]
1044
1045        for a_pint, a_pint_array in zip(a_pints, a_pint_arrays):
1046            for b in bs:
1047                for op in arithmetic_ops:
1048                    test_op(a_pint, a_pint_array, b)
1049                for op in comparative_ops:
1050                    test_op(a_pint, a_pint_array, b, coerce=False)
1051
1052    def test_mismatched_dimensions(self):
1053        x_and_ys = [
1054            (PintArray.from_1darray_quantity(ureg.Quantity([5.0], "m")), [1, 1]),
1055            (
1056                PintArray.from_1darray_quantity(ureg.Quantity([5.0, 5.0, 5.0], "m")),
1057                [1, 1],
1058            ),
1059            (PintArray.from_1darray_quantity(self.Q_([5.0, 5.0], "m")), [1]),
1060        ]
1061        for x, y in x_and_ys:
1062            for op in comparative_ops + arithmetic_ops:
1063                with pytest.raises(ValueError):
1064                    op(x, y)
1065