1import doctest
2import math
3import pickle
4import re
5import warnings
6from distutils.version import LooseVersion
7from numbers import Number
8
9import pytest
10
11from pint import Quantity
12from pint.compat import ndarray, np
13
14from ..compat import (
15    HAS_BABEL,
16    HAS_NUMPY,
17    HAS_NUMPY_ARRAY_FUNCTION,
18    HAS_UNCERTAINTIES,
19    NUMPY_VER,
20)
21
# A decimal number with optional sign, optional fraction and optional exponent.
_number_re = r"([-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?)"

# Matches the repr form ``<Quantity(<magnitude>, '<unit>')>``.
_q_re = re.compile(
    r"<Quantity\("
    + r"\s*"
    + r"(?P<magnitude>%s)" % _number_re
    + r"\s*,\s*"
    + r"'(?P<unit>.*)'"
    + r"\s*"
    + r"\)>"
)

# Matches the plain string form ``<magnitude> <unit>``.
_sq_re = re.compile(
    r"\s*" + r"(?P<magnitude>%s)" % _number_re + r"\s" + r"(?P<unit>.*)"
)

# Matches ``<Unit(...)>``; group 1 is the text between the parentheses.
_unit_re = re.compile(r"<Unit\((.*)\)>")


class PintOutputChecker(doctest.OutputChecker):
    """Doctest output checker that tolerates small numeric differences.

    The stock doctest comparison is tried first; the extra rules below only
    apply when it rejects the output.
    """

    def check_output(self, want, got, optionflags):
        """Return True when ``got`` is an acceptable match for ``want``.

        The checks are applied in this order:

        1. The standard ``doctest.OutputChecker`` comparison.
        2. Equality of the two strings after evaluating them with ``eval``.
        3. If both strings look like a quantity (``<Quantity(m, 'u')>`` or
           ``m u``): the magnitudes must differ by at most one thousandth of
           the magnitude found in ``got`` and the unit texts must be identical.
        4. ``<Unit(...)>`` wrappers are replaced by their content; if the
           results are equal the output is accepted, otherwise the whole
           procedure is retried on the unwrapped strings.

        Parameters
        ----------
        want : str
            Expected output taken from the doctest.
        got : str
            Output actually produced.
        optionflags : int
            Doctest option flags, forwarded to the base implementation.

        Returns
        -------
        bool
        """
        check = super().check_output(want, got, optionflags)
        if check:
            return check

        # NOTE(review): eval() runs arbitrary code. This is only acceptable
        # because the strings come from the project's own doctests; never feed
        # untrusted text through this checker.
        try:
            if eval(want) == eval(got):
                return True
        except Exception:
            # Most outputs are not valid expressions (or not comparable);
            # fall through to the textual checks.
            pass

        for regex in (_q_re, _sq_re):
            # Every occurrence of two consecutive backslashes is removed
            # before matching.
            got_match = regex.match(got.replace(r"\\", ""))
            want_match = regex.match(want.replace(r"\\", ""))
            if got_match is None or want_match is None:
                # Not both in this format: try the next one.
                continue

            v1 = float(got_match.group("magnitude"))
            v2 = float(want_match.group("magnitude"))

            # Relative tolerance of 0.1 %, measured against the obtained value.
            if abs(v1 - v2) > abs(v1) / 1000:
                return False

            return got_match.group("unit") == want_match.group("unit")

        # Unwrap ``<Unit(...)>``. The replacement must be the raw string
        # r"\1" (backreference to group 1); a plain "\1" is the control
        # character chr(1) and would make every unit look identical.
        parsed_got, got_count = _unit_re.subn(r"\1", got)
        parsed_want, want_count = _unit_re.subn(r"\1", want)

        if parsed_got == parsed_want:
            return True

        if got_count or want_count:
            # If there was any replacement, we try again the previous methods.
            # Each replacement shortens a string, so the recursion terminates.
            return self.check_output(parsed_want, parsed_got, optionflags)

        return False
89
90
def _get_comparable_magnitudes(first, second, msg):
    """Reduce two values to a pair of magnitudes that can be compared directly.

    - Both are ``Quantity``: ``second`` is converted with ``second.to(first)``,
      the resulting units are asserted equal, and both magnitudes are returned.
    - Only one is a ``Quantity``: it is asserted to be dimensionless,
      converted with ``.to("")``, and its magnitude is paired with the other
      value unchanged.
    - Neither is a ``Quantity``: both values are returned unchanged.

    ``msg`` is used as the prefix of the assertion messages.

    Returns a tuple ``(m1, m2)`` in the same order as the arguments.
    """
    first_is_quantity = isinstance(first, Quantity)
    second_is_quantity = isinstance(second, Quantity)

    if first_is_quantity and second_is_quantity:
        second = second.to(first)
        assert first.units == second.units, msg + " Units are not equal."
        return first.magnitude, second.magnitude

    if first_is_quantity:
        assert first.dimensionless, msg + " The first is not dimensionless."
        return first.to("").magnitude, second

    if second_is_quantity:
        assert second.dimensionless, msg + " The second is not dimensionless."
        return first, second.to("").magnitude

    return first, second
108
109
def assert_quantity_equal(first, second, msg=None):
    """Assert that two values are exactly equal after unit reconciliation.

    Both arguments are first reduced to magnitudes by
    ``_get_comparable_magnitudes``. Then:

    - if either magnitude is an ``ndarray``, the comparison is delegated to
      ``np.testing.assert_array_equal``;
    - if either magnitude is not a ``numbers.Number``, a ``RuntimeWarning`` is
      issued and the comparison is skipped (nothing is asserted);
    - two NaN values are considered equal;
    - otherwise ``m1 == m2`` is asserted.

    Parameters
    ----------
    first, second
        Quantities, numbers or arrays to compare.
    msg : str, optional
        Prefix for the failure message. Defaults to a description of the two
        arguments.

    Raises
    ------
    AssertionError
        If the values are not equal.
    """
    if msg is None:
        msg = "Comparing %r and %r. " % (first, second)

    m1, m2 = _get_comparable_magnitudes(first, second, msg)
    msg += " (Converted to %r and %r)" % (m1, m2)

    if isinstance(m1, ndarray) or isinstance(m2, ndarray):
        np.testing.assert_array_equal(m1, m2, err_msg=msg)
        return

    for magnitude in (m1, m2):
        if not isinstance(magnitude, Number):
            # warnings.warn takes the message first and the category second;
            # passing only the class would emit a UserWarning whose text is
            # the repr of the class.
            warnings.warn(
                "Comparison skipped: magnitude of type %s is not a Number. %s"
                % (type(magnitude).__name__, msg),
                RuntimeWarning,
                stacklevel=2,
            )
            return

    if math.isnan(m1) or math.isnan(m2):
        # NaN never equals itself, so NaN-ness is compared explicitly.
        assert math.isnan(m1) and math.isnan(m2), msg
    else:
        assert m1 == m2, msg
131
132
def assert_quantity_almost_equal(first, second, rtol=1e-07, atol=0, msg=None):
    """Assert that two values are equal within a tolerance.

    Both arguments are first reduced to magnitudes by
    ``_get_comparable_magnitudes``. Then:

    - if either magnitude is an ``ndarray``, the comparison is delegated to
      ``np.testing.assert_allclose`` with the given ``rtol`` and ``atol``;
    - if either magnitude is not a ``numbers.Number``, a ``RuntimeWarning`` is
      issued and the comparison is skipped (nothing is asserted);
    - two NaN values are considered equal;
    - an infinite value only matches an infinite value of the same sign;
    - otherwise the symmetric criterion
      ``abs(m1 - m2) <= max(rtol * max(abs(m1), abs(m2)), atol)`` is asserted.

    Parameters
    ----------
    first, second
        Quantities, numbers or arrays to compare.
    rtol : float, optional
        Relative tolerance (default ``1e-07``).
    atol : float, optional
        Absolute tolerance (default ``0``).
    msg : str, optional
        Prefix for the failure message. Defaults to a description of the two
        arguments.

    Raises
    ------
    AssertionError
        If the values differ by more than the tolerance.
    """
    if msg is None:
        try:
            msg = "Comparing %r and %r. " % (first, second)
        except TypeError:
            # Fall back to str() and finally to a fixed text if the
            # arguments cannot be formatted.
            try:
                msg = "Comparing %s and %s. " % (first, second)
            except Exception:
                msg = "Comparing"

    m1, m2 = _get_comparable_magnitudes(first, second, msg)
    msg += " (Converted to %r and %r)" % (m1, m2)

    if isinstance(m1, ndarray) or isinstance(m2, ndarray):
        np.testing.assert_allclose(m1, m2, rtol=rtol, atol=atol, err_msg=msg)
        return

    for magnitude in (m1, m2):
        if not isinstance(magnitude, Number):
            # warnings.warn takes the message first and the category second;
            # passing only the class would emit a UserWarning whose text is
            # the repr of the class.
            warnings.warn(
                "Comparison skipped: magnitude of type %s is not a Number. %s"
                % (type(magnitude).__name__, msg),
                RuntimeWarning,
                stacklevel=2,
            )
            return

    if math.isnan(m1) or math.isnan(m2):
        # NaN never equals itself, so NaN-ness is compared explicitly.
        assert math.isnan(m1) and math.isnan(m2), msg
    elif math.isinf(m1) or math.isinf(m2):
        # Checking only isinf() on both sides would accept +inf vs -inf;
        # equality also requires the same sign.
        assert m1 == m2, msg
    else:
        # Symmetric in m1 and m2, unlike NumPy's ``atol + rtol * abs(m2)``.
        assert abs(m1 - m2) <= max(rtol * max(abs(m1), abs(m2)), atol), msg
166
167
# Skip markers driven by the HAS_NUMPY flag.
# ``requires_numpy`` skips when the flag is false, ``requires_not_numpy``
# skips when it is true.
requires_numpy = pytest.mark.skipif(
    not HAS_NUMPY,
    reason="Requires NumPy",
)
requires_not_numpy = pytest.mark.skipif(
    HAS_NUMPY,
    reason="Requires NumPy not to be installed.",
)
172
173
def requires_array_function_protocol():
    """Return a pytest marker for tests that need ``__array_function__``.

    When ``HAS_NUMPY`` is false the marker skips unconditionally; otherwise
    it skips when ``HAS_NUMPY_ARRAY_FUNCTION`` is false.
    """
    if HAS_NUMPY:
        return pytest.mark.skipif(
            not HAS_NUMPY_ARRAY_FUNCTION,
            reason="Requires __array_function__ protocol to be enabled",
        )
    return pytest.mark.skip("Requires NumPy")
181
182
def requires_not_array_function_protocol():
    """Return a pytest marker for tests that need ``__array_function__`` off.

    When ``HAS_NUMPY`` is false the marker skips unconditionally; otherwise
    it skips when ``HAS_NUMPY_ARRAY_FUNCTION`` is true.
    """
    if HAS_NUMPY:
        return pytest.mark.skipif(
            HAS_NUMPY_ARRAY_FUNCTION,
            reason="Requires __array_function__ protocol to be unavailable or disabled",
        )
    return pytest.mark.skip("Requires NumPy")
190
191
def requires_numpy_previous_than(version):
    """Return a pytest marker for tests that need NumPy older than ``version``.

    When ``HAS_NUMPY`` is false the marker skips unconditionally; otherwise
    it skips unless ``NUMPY_VER`` compares lower than ``version`` (both
    wrapped in ``LooseVersion``).

    NOTE(review): ``LooseVersion`` is imported from ``distutils``, which is
    no longer part of the standard library from Python 3.12 on - confirm the
    supported Python versions.
    """
    if HAS_NUMPY:
        is_older = LooseVersion(NUMPY_VER) < LooseVersion(version)
        return pytest.mark.skipif(
            not is_older,
            reason="Requires NumPy < %s" % version,
        )
    return pytest.mark.skip("Requires NumPy")
199
200
def requires_numpy_at_least(version):
    """Return a pytest marker for tests that need NumPy ``version`` or newer.

    When ``HAS_NUMPY`` is false the marker skips unconditionally; otherwise
    it skips unless ``NUMPY_VER`` compares greater than or equal to
    ``version`` (both wrapped in ``LooseVersion``).

    NOTE(review): ``LooseVersion`` is imported from ``distutils``, which is
    no longer part of the standard library from Python 3.12 on - confirm the
    supported Python versions.
    """
    if HAS_NUMPY:
        is_recent_enough = LooseVersion(NUMPY_VER) >= LooseVersion(version)
        return pytest.mark.skipif(
            not is_recent_enough,
            reason="Requires NumPy >= %s" % version,
        )
    return pytest.mark.skip("Requires NumPy")
208
209
# Skip markers driven by the HAS_BABEL flag.
requires_babel = pytest.mark.skipif(
    not HAS_BABEL,
    reason="Requires Babel with units support",
)
requires_not_babel = pytest.mark.skipif(
    HAS_BABEL,
    reason="Requires Babel not to be installed",
)

# Skip markers driven by the HAS_UNCERTAINTIES flag.
requires_uncertainties = pytest.mark.skipif(
    not HAS_UNCERTAINTIES,
    reason="Requires Uncertainties",
)
requires_not_uncertainties = pytest.mark.skipif(
    HAS_UNCERTAINTIES,
    reason="Requires Uncertainties not to be installed.",
)
222
# Parametrization

# One test run per pickle protocol number, from 0 up to and including
# pickle.HIGHEST_PROTOCOL, passed as the ``protocol`` argument.
allprotos = pytest.mark.parametrize(
    ("protocol",),
    [(protocol,) for protocol in range(pickle.HIGHEST_PROTOCOL + 1)],
)

# Two test runs: ``check_all`` is False, then True.
check_all_bool = pytest.mark.parametrize("check_all", [False, True])
230