"""Shared helpers for the pint test suite.

Provides a tolerant doctest output checker, assertion helpers that compare
``Quantity`` objects (or plain numbers / arrays), and reusable pytest markers
for optional dependencies and parametrization.
"""

import cmath
import doctest
import math
import pickle
import re
import warnings
from numbers import Number

import pytest

from pint import Quantity
from pint.compat import ndarray, np

from ..compat import (
    HAS_BABEL,
    HAS_NUMPY,
    HAS_NUMPY_ARRAY_FUNCTION,
    HAS_UNCERTAINTIES,
    NUMPY_VER,
)

try:
    from distutils.version import LooseVersion
except ImportError:  # distutils was removed from the stdlib in Python 3.12
    LooseVersion = None

_number_re = r"([-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?)"

# Matches a Quantity repr such as ``<Quantity(1.5, 'meter')>``.
_q_re = re.compile(
    r"<Quantity\("
    + r"\s*"
    + r"(?P<magnitude>%s)" % _number_re
    + r"\s*,\s*"
    + r"'(?P<unit>.*)'"
    + r"\s*"
    + r"\)>"
)

# Matches a magnitude, one whitespace character, then the unit text,
# e.g. ``1.5 meter``.
_sq_re = re.compile(
    r"\s*" + r"(?P<magnitude>%s)" % _number_re + r"\s" + r"(?P<unit>.*)"
)

# Matches a Unit repr such as ``<Unit('meter')>``; group 1 is the inner text.
_unit_re = re.compile(r"<Unit\((.*)\)>")


class PintOutputChecker(doctest.OutputChecker):
    """Doctest output checker that tolerates pint-specific formatting noise.

    On top of the exact comparison of :class:`doctest.OutputChecker`, output
    is accepted when:

    * both texts ``eval`` to objects that compare equal;
    * both texts match the same Quantity pattern (``_q_re`` or ``_sq_re``),
      their magnitudes differ by at most 0.1% of the actual magnitude, and
      their unit strings are identical;
    * they become identical once every ``<Unit(...)>`` wrapper is reduced to
      its inner text; if any wrapper was reduced, all checks are retried on
      the reduced text.
    """

    def check_output(self, want, got, optionflags):
        check = super().check_output(want, got, optionflags)
        if check:
            return check

        # NOTE: doctest text comes from the project's own documentation, so
        # eval is tolerated here; never reuse this checker on untrusted text.
        try:
            if eval(want) == eval(got):
                return True
        except Exception:
            # Unparseable text, or a comparison without a plain truth value
            # (e.g. element-wise array results): fall through to other checks.
            pass

        verdict = self._check_quantity_text(want, got)
        if verdict is not None:
            return verdict

        # Reduce ``<Unit(...)>`` wrappers to their content. The replacement
        # must be the raw string r"\1" (a backreference); a plain "\1" is the
        # control character U+0001, which would make any two Unit reprs equal.
        parsed_got, n_got = _unit_re.subn(r"\1", got)
        parsed_want, n_want = _unit_re.subn(r"\1", want)

        if parsed_got == parsed_want:
            return True

        if n_got or n_want:
            # Something was unwrapped: retry every check on the reduced text.
            # Each pass strictly shortens the text, so recursion terminates.
            return self.check_output(parsed_want, parsed_got, optionflags)

        return False

    @staticmethod
    def _check_quantity_text(want, got):
        """Compare *want*/*got* as Quantity text.

        Returns ``True``/``False`` for the first pattern (``_q_re``, then
        ``_sq_re``) that matches both texts, or ``None`` when no pattern
        matches both.
        """
        # Escaped backslash pairs are dropped before matching.
        cleaned_got = got.replace(r"\\", "")
        cleaned_want = want.replace(r"\\", "")

        for regex in (_q_re, _sq_re):
            match_got = regex.match(cleaned_got)
            match_want = regex.match(cleaned_want)
            if match_got is None or match_want is None:
                continue

            # The magnitude group only matches valid float literals, so float()
            # cannot fail here.
            v1 = float(match_got.group("magnitude"))
            v2 = float(match_want.group("magnitude"))

            # Relative tolerance of 0.1%, measured against the actual output.
            if abs(v1 - v2) > abs(v1) / 1000:
                return False

            return match_got.group("unit") == match_want.group("unit")

        return None


def _get_comparable_magnitudes(first, second, msg):
    """Return the magnitudes of *first* and *second* expressed comparably.

    * Two quantities: *second* is converted to the units of *first*.
    * One quantity and one plain value: the quantity must be dimensionless
      and is converted to ``""`` before taking its magnitude.
    * Two plain values are returned unchanged.

    Raises ``AssertionError`` (with *msg* as prefix) when the units cannot be
    reconciled as described above.
    """
    if isinstance(first, Quantity) and isinstance(second, Quantity):
        second = second.to(first)
        assert first.units == second.units, msg + " Units are not equal."
        m1, m2 = first.magnitude, second.magnitude
    elif isinstance(first, Quantity):
        assert first.dimensionless, msg + " The first is not dimensionless."
        first = first.to("")
        m1, m2 = first.magnitude, second
    elif isinstance(second, Quantity):
        assert second.dimensionless, msg + " The second is not dimensionless."
        second = second.to("")
        m1, m2 = first, second.magnitude
    else:
        m1, m2 = first, second

    return m1, m2


def _comparison_message(first, second):
    """Build the default assertion message, tolerating objects whose repr fails."""
    try:
        return "Comparing %r and %r. " % (first, second)
    except TypeError:
        try:
            return "Comparing %s and %s. " % (first, second)
        except Exception:
            return "Comparing"


def assert_quantity_equal(first, second, msg=None):
    """Assert that two quantities (or numbers / arrays) are exactly equal.

    Quantities are first brought to comparable magnitudes (see
    ``_get_comparable_magnitudes``). Arrays are compared with
    ``np.testing.assert_array_equal``. For scalars, NaN is considered equal to
    NaN. Non-numeric magnitudes are not compared; a ``RuntimeWarning`` is
    emitted instead.
    """
    if msg is None:
        msg = _comparison_message(first, second)

    m1, m2 = _get_comparable_magnitudes(first, second, msg)
    msg += " (Converted to %r and %r)" % (m1, m2)

    if isinstance(m1, ndarray) or isinstance(m2, ndarray):
        np.testing.assert_array_equal(m1, m2, err_msg=msg)
    elif not isinstance(m1, Number) or not isinstance(m2, Number):
        # The original passed the RuntimeWarning *class* as the message, which
        # emitted a UserWarning reading "<class 'RuntimeWarning'>".
        warnings.warn(
            msg + " Magnitudes are not numbers; comparison skipped.",
            RuntimeWarning,
            stacklevel=2,
        )
    elif cmath.isnan(m1) or cmath.isnan(m2):
        # cmath handles complex magnitudes too (math.isnan raises TypeError).
        assert cmath.isnan(m1) and cmath.isnan(m2), msg
    else:
        assert m1 == m2, msg


def assert_quantity_almost_equal(first, second, rtol=1e-07, atol=0, msg=None):
    """Assert that two quantities (or numbers / arrays) are approximately equal.

    Arrays are compared with ``np.testing.assert_allclose(rtol, atol)``.
    Scalars pass when ``|m1 - m2| <= max(rtol * max(|m1|, |m2|), atol)``.
    NaN only matches NaN, and an infinity only matches the same infinity
    (``inf`` and ``-inf`` are different). Non-numeric magnitudes are not
    compared; a ``RuntimeWarning`` is emitted instead.
    """
    if msg is None:
        msg = _comparison_message(first, second)

    m1, m2 = _get_comparable_magnitudes(first, second, msg)
    msg += " (Converted to %r and %r)" % (m1, m2)

    if isinstance(m1, ndarray) or isinstance(m2, ndarray):
        np.testing.assert_allclose(m1, m2, rtol=rtol, atol=atol, err_msg=msg)
    elif not isinstance(m1, Number) or not isinstance(m2, Number):
        warnings.warn(
            msg + " Magnitudes are not numbers; comparison skipped.",
            RuntimeWarning,
            stacklevel=2,
        )
    elif cmath.isnan(m1) or cmath.isnan(m2):
        assert cmath.isnan(m1) and cmath.isnan(m2), msg
    elif cmath.isinf(m1) or cmath.isinf(m2):
        # inf - inf is NaN, so infinities need exact comparison; checking only
        # isinf on both sides would wrongly accept inf vs -inf.
        assert m1 == m2, msg
    else:
        # Symmetric tolerance. numpy's ``atol + rtol * abs(m2)`` is not
        # symmetric in its arguments.
        assert abs(m1 - m2) <= max(rtol * max(abs(m1), abs(m2)), atol), msg


_version_component_re = re.compile(r"(\d+|[a-z]+|\.)")


def _version_key(vstring):
    """Return an object that orders version strings like ``LooseVersion``.

    Uses ``distutils.version.LooseVersion`` when importable. Otherwise it
    splits *vstring* with the same component pattern (digit runs, lowercase
    letter runs, dots dropped) and converts numeric parts to ``int``. The
    resulting lists compare the way ``LooseVersion`` compares its components.
    """
    if LooseVersion is not None:
        return LooseVersion(vstring)

    parts = []
    for piece in _version_component_re.split(vstring):
        if not piece or piece == ".":
            continue
        try:
            parts.append(int(piece))
        except ValueError:
            parts.append(piece)
    return parts


requires_numpy = pytest.mark.skipif(not HAS_NUMPY, reason="Requires NumPy")
requires_not_numpy = pytest.mark.skipif(
    HAS_NUMPY, reason="Requires NumPy not to be installed."
)


def requires_array_function_protocol():
    """Return a marker that skips unless NumPy's ``__array_function__`` is enabled."""
    if not HAS_NUMPY:
        return pytest.mark.skip("Requires NumPy")
    return pytest.mark.skipif(
        not HAS_NUMPY_ARRAY_FUNCTION,
        reason="Requires __array_function__ protocol to be enabled",
    )


def requires_not_array_function_protocol():
    """Return a marker that skips when NumPy's ``__array_function__`` is enabled."""
    if not HAS_NUMPY:
        return pytest.mark.skip("Requires NumPy")
    return pytest.mark.skipif(
        HAS_NUMPY_ARRAY_FUNCTION,
        reason="Requires __array_function__ protocol to be unavailable or disabled",
    )


def requires_numpy_previous_than(version):
    """Return a marker that skips unless the installed NumPy is older than *version*."""
    if not HAS_NUMPY:
        return pytest.mark.skip("Requires NumPy")
    return pytest.mark.skipif(
        not _version_key(NUMPY_VER) < _version_key(version),
        reason="Requires NumPy < %s" % version,
    )


def requires_numpy_at_least(version):
    """Return a marker that skips unless the installed NumPy is at least *version*."""
    if not HAS_NUMPY:
        return pytest.mark.skip("Requires NumPy")
    return pytest.mark.skipif(
        not _version_key(NUMPY_VER) >= _version_key(version),
        reason="Requires NumPy >= %s" % version,
    )


requires_babel = pytest.mark.skipif(
    not HAS_BABEL, reason="Requires Babel with units support"
)
requires_not_babel = pytest.mark.skipif(
    HAS_BABEL, reason="Requires Babel not to be installed"
)
requires_uncertainties = pytest.mark.skipif(
    not HAS_UNCERTAINTIES, reason="Requires Uncertainties"
)
requires_not_uncertainties = pytest.mark.skipif(
    HAS_UNCERTAINTIES, reason="Requires Uncertainties not to be installed."
)

# Parametrization

# Run a test once per pickle protocol supported by this interpreter.
allprotos = pytest.mark.parametrize(
    ("protocol",), [(p,) for p in range(pickle.HIGHEST_PROTOCOL + 1)]
)

check_all_bool = pytest.mark.parametrize("check_all", [False, True])