"""Miscellaneous functions for testing masked arrays and subclasses

:author: Pierre Gerard-Marchant
:contact: pierregm_at_uga_dot_edu
:version: $Id: testutils.py 3529 2007-11-13 08:01:14Z jarrod.millman $

"""
import operator

import numpy as np
from numpy import ndarray, float_
import numpy.core.umath as umath
import numpy.testing
from numpy.testing import (
    assert_, assert_allclose, assert_array_almost_equal_nulp,
    assert_raises, build_err_msg
    )
from .core import mask_or, getmask, masked_array, nomask, masked, filled

__all__masked = [
    'almost', 'approx', 'assert_almost_equal', 'assert_array_almost_equal',
    'assert_array_approx_equal', 'assert_array_compare',
    'assert_array_equal', 'assert_array_less', 'assert_close',
    'assert_equal', 'assert_equal_records', 'assert_mask_equal',
    'assert_not_equal', 'fail_if_array_equal',
    ]

# Include some normal test functions to avoid breaking other projects that
# have mistakenly imported them from this file. SciPy is one. That is
# unfortunate, as some of these functions are not intended to work with
# masked arrays. But there was no way to tell before.
from unittest import TestCase
__some__from_testing = [
    'TestCase', 'assert_', 'assert_allclose', 'assert_array_almost_equal_nulp',
    'assert_raises'
    ]

__all__ = __all__masked + __some__from_testing


def approx(a, b, fill_value=True, rtol=1e-5, atol=1e-8):
    """
    Compare `a` and `b` elementwise within relative/absolute tolerances.

    An element passes when ``abs(a - b) <= atol + rtol * abs(b)``.

    Parameters
    ----------
    a, b : array_like
        Values to compare; either may be a masked array.
    fill_value : bool, optional
        Value substituted in `a` wherever `a` or `b` is masked, while `b`
        is always filled with 1 there.  With True (the default) masked
        positions therefore compare equal; with False they compare
        unequal.
    rtol : float, optional
        Relative tolerance; should be positive and much smaller than 1.
    atol : float, optional
        Absolute tolerance; it dominates where `b` is very small or zero.

    Returns
    -------
    ndarray
        Flattened boolean result.  If either filled input has object
        dtype, the result is the exact ``np.equal`` comparison of the
        filled inputs and the tolerances are ignored.

    """
    m = mask_or(getmask(a), getmask(b))
    d1 = filled(a)
    d2 = filled(b)
    # Tolerance arithmetic is meaningless on object arrays: compare exactly.
    if d1.dtype.char == "O" or d2.dtype.char == "O":
        return np.equal(d1, d2).ravel()
    # Apply the common mask to both sides, then fill so that the masked
    # slots compare according to `fill_value` (True == 1 matches the 1 used
    # for `y`; False == 0 does not).
    x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
    y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
    d = np.less_equal(umath.absolute(x - y), atol + rtol * umath.absolute(y))
    return d.ravel()


def almost(a, b, decimal=6, fill_value=True):
    """
    Compare `a` and `b` elementwise up to `decimal` decimal places.

    An element passes when ``around(abs(a - b), decimal) <= 10**-decimal``.

    Parameters
    ----------
    a, b : array_like
        Values to compare; either may be a masked array.
    decimal : int, optional
        Number of decimal places used for the comparison.
    fill_value : bool, optional
        Value substituted in `a` wherever `a` or `b` is masked, while `b`
        is always filled with 1 there.  With True (the default) masked
        positions therefore compare equal; with False they compare
        unequal.

    Returns
    -------
    ndarray
        Flattened boolean result.  If either filled input has object
        dtype, the result is the exact ``np.equal`` comparison of the
        filled inputs.

    """
    m = mask_or(getmask(a), getmask(b))
    d1 = filled(a)
    d2 = filled(b)
    # Rounding is meaningless on object arrays: compare exactly.
    if d1.dtype.char == "O" or d2.dtype.char == "O":
        return np.equal(d1, d2).ravel()
    x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
    y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
    d = np.around(np.abs(x - y), decimal) <= 10.0 ** (-decimal)
    return d.ravel()


def _assert_equal_on_sequences(actual, desired, err_msg=''):
    """
    Assert the equality of two non-array sequences.

    The lengths are compared first, then every pair of items is compared
    with `assert_equal`; the failing index is prepended to `err_msg`.

    """
    assert_equal(len(actual), len(desired), err_msg)
    # Lengths are known to match here, so zip() visits every index.
    for k, (actual_item, desired_item) in enumerate(zip(actual, desired)):
        assert_equal(actual_item, desired_item, f'item={k!r}\n{err_msg}')


def assert_equal_records(a, b):
    """
    Assert that two records are equal.

    The dtypes must match, then every named field is compared with
    `assert_equal`.  A field is skipped when its value is the `masked`
    constant in either record.  Pretty crude for now.

    """
    assert_equal(a.dtype, b.dtype)
    for f in a.dtype.names:
        af, bf = a[f], b[f]
        if af is not masked and bf is not masked:
            assert_equal(af, bf)


def assert_equal(actual, desired, err_msg=''):
    """
    Assert that two items are equal.

    Dictionaries are compared key by key, lists/tuples item by item, other
    non-array objects with ``==``, and arrays with `assert_array_equal`.

    Parameters
    ----------
    actual, desired : object
        The objects to compare.
    err_msg : str, optional
        Message included in the error raised on failure.

    Raises
    ------
    AssertionError
        If the two objects are not equal.
    ValueError
        If exactly one of the two objects is the `masked` constant while
        the other one is an ndarray.

    """
    # Case #1: dictionaries
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg)
        for k, desired_value in desired.items():
            if k not in actual:
                raise AssertionError(f"{k} not in {actual}")
            assert_equal(actual[k], desired_value, f'key={k!r}\n{err_msg}')
        return
    # Case #2: lists and tuples
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        # Forward the caller's message instead of silently dropping it.
        return _assert_equal_on_sequences(actual, desired, err_msg=err_msg)
    # Case #3: neither side is an array
    if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
        msg = build_err_msg([actual, desired], err_msg,)
        if not desired == actual:
            raise AssertionError(msg)
        return
    # Case #4: arrays or equivalent
    if (actual is masked) != (desired is masked):
        msg = build_err_msg([actual, desired],
                            err_msg, header='', names=('x', 'y'))
        raise ValueError(msg)
    # asanyarray avoids a copy when possible and keeps ndarray subclasses.
    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    (actual_dtype, desired_dtype) = (actual.dtype, desired.dtype)
    if actual_dtype.char == "S" and desired_dtype.char == "S":
        # Byte-string arrays are compared as nested Python lists.
        return _assert_equal_on_sequences(actual.tolist(),
                                          desired.tolist(),
                                          err_msg=err_msg)
    return assert_array_equal(actual, desired, err_msg)


def fail_if_equal(actual, desired, err_msg='',):
    """
    Raise an assertion error if two items are equal.

    Dictionaries are checked key by key and lists/tuples item by item
    (their lengths are checked the same way first); arrays are delegated
    to `fail_if_array_equal`; anything else is compared with ``!=``.

    Parameters
    ----------
    actual, desired : object
        The objects to compare.
    err_msg : str, optional
        Message included in the error raised on failure.

    Raises
    ------
    AssertionError
        If a compared pair is equal, if `desired` is a dict and `actual`
        is not, or if a key of `desired` is missing from `actual`.

    """
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        fail_if_equal(len(actual), len(desired), err_msg)
        for k, desired_value in desired.items():
            if k not in actual:
                raise AssertionError(repr(k))
            fail_if_equal(actual[k], desired_value, f'key={k!r}\n{err_msg}')
        return
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        fail_if_equal(len(actual), len(desired), err_msg)
        for k in range(len(desired)):
            fail_if_equal(actual[k], desired[k], f'item={k!r}\n{err_msg}')
        return
    if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
        return fail_if_array_equal(actual, desired, err_msg)
    msg = build_err_msg([actual, desired], err_msg)
    if not desired != actual:
        raise AssertionError(msg)


assert_not_equal = fail_if_equal


def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
    """
    Assert that two items are almost equal.

    For non-array inputs the check is
    ``round(abs(desired - actual), decimal) == 0``, which is roughly
    ``abs(desired - actual) < 0.5 * 10**(-decimal)``.  If either input is
    an ndarray the check is delegated to `assert_array_almost_equal`.

    Parameters
    ----------
    actual, desired : object
        The values to compare.
    decimal : int, optional
        Number of decimal places used for the comparison.
    err_msg : str, optional
        Message included in the error raised on failure.
    verbose : bool, optional
        Forwarded to `build_err_msg` / `assert_array_almost_equal`.

    Raises
    ------
    AssertionError
        If the two items differ at the requested precision.

    """
    if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
        return assert_array_almost_equal(actual, desired, decimal=decimal,
                                         err_msg=err_msg, verbose=verbose)
    msg = build_err_msg([actual, desired],
                        err_msg=err_msg, verbose=verbose)
    if not round(abs(desired - actual), decimal) == 0:
        raise AssertionError(msg)


assert_close = assert_almost_equal


def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
                         fill_value=True):
    """
    Assert that a comparison between two masked arrays is satisfied.

    The comparison is elementwise.  Both inputs are given the union of
    their masks, filled with `fill_value`, and handed to
    ``np.testing.assert_array_compare``.

    Parameters
    ----------
    comparison : callable
        Function of two arrays implementing the comparison.
    x, y : array_like
        The arrays to compare.
    err_msg : str, optional
        Message included in the error raised on failure.
    verbose : bool, optional
        Forwarded to the error-message builder.
    header : str, optional
        Header of the error message.
    fill_value : scalar, optional
        Value used to fill the masked positions of both arrays.

    Raises
    ------
    ValueError
        If exactly one of the re-masked arrays is the `masked` constant.

    """
    # Allocate a common mask and refill
    m = mask_or(getmask(x), getmask(y))
    x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False)
    y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False)
    if ((x is masked) and not (y is masked)) or \
            ((y is masked) and not (x is masked)):
        msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose,
                            header=header, names=('x', 'y'))
        raise ValueError(msg)
    # OK, now run the basic tests on filled versions
    return np.testing.assert_array_compare(comparison,
                                           x.filled(fill_value),
                                           y.filled(fill_value),
                                           err_msg=err_msg,
                                           verbose=verbose, header=header)


def assert_array_equal(x, y, err_msg='', verbose=True):
    """
    Check the elementwise equality of two masked arrays.

    """
    assert_array_compare(operator.__eq__, x, y,
                         err_msg=err_msg, verbose=verbose,
                         header='Arrays are not equal')


def fail_if_array_equal(x, y, err_msg='', verbose=True):
    """
    Raise an assertion error if two masked arrays are equal elementwise.

    Equality is evaluated with `approx`, i.e. within its default
    tolerances; the check fails only when every element matches.

    """
    def compare(x, y):
        "Return True when at least one element of x and y differs."
        return (not np.all(approx(x, y)))
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header='Arrays are not equal')


def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Check the equality of two masked arrays, up to given number of decimals.

    The equality is checked elementwise with `approx`, using a relative
    tolerance of ``10**-decimal``.

    """
    def compare(x, y):
        "Return the result of the loose comparison between x and y."
        return approx(x, y, rtol=10. ** -decimal)
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header='Arrays are not almost equal')


def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Check the equality of two masked arrays, up to given number of decimals.

    The equality is checked elementwise with `almost`.

    """
    def compare(x, y):
        "Return the result of the loose comparison between x and y."
        return almost(x, y, decimal)
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header='Arrays are not almost equal')


def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Check that x is smaller than y elementwise.

    """
    assert_array_compare(operator.__lt__, x, y,
                         err_msg=err_msg, verbose=verbose,
                         header='Arrays are not less-ordered')


def assert_mask_equal(m1, m2, err_msg=''):
    """
    Assert the equality of two masks.

    If either mask is `nomask`, the other one must be `nomask` as well;
    the masks are then compared with `assert_array_equal`.

    """
    if m1 is nomask:
        assert_(m2 is nomask)
    if m2 is nomask:
        assert_(m1 is nomask)
    assert_array_equal(m1, m2, err_msg=err_msg)