1"""Miscellaneous functions for testing masked arrays and subclasses
2
3:author: Pierre Gerard-Marchant
4:contact: pierregm_at_uga_dot_edu
5:version: $Id: testutils.py 3529 2007-11-13 08:01:14Z jarrod.millman $
6
7"""
8import operator
9
10import numpy as np
11from numpy import ndarray, float_
12import numpy.core.umath as umath
13import numpy.testing
14from numpy.testing import (
15    assert_, assert_allclose, assert_array_almost_equal_nulp,
16    assert_raises, build_err_msg
17    )
18from .core import mask_or, getmask, masked_array, nomask, masked, filled
19
# Masked-array aware helpers defined in this module.
__all__masked = [
    'almost',
    'approx',
    'assert_almost_equal',
    'assert_array_almost_equal',
    'assert_array_approx_equal',
    'assert_array_compare',
    'assert_array_equal',
    'assert_array_less',
    'assert_close',
    'assert_equal',
    'assert_equal_records',
    'assert_mask_equal',
    'assert_not_equal',
    'fail_if_array_equal',
]

# Include some normal test functions to avoid breaking other projects that
# have mistakenly imported them from this file. SciPy is one. That is
# unfortunate, as some of these functions are not intended to work with
# masked arrays. But there was no way to tell before.
from unittest import TestCase
__some__from_testing = [
    'TestCase',
    'assert_',
    'assert_allclose',
    'assert_array_almost_equal_nulp',
    'assert_raises',
]

# Public API: the masked helpers first, then the plain re-exports.
__all__ = [*__all__masked, *__some__from_testing]
39
40
def approx(a, b, fill_value=True, rtol=1e-5, atol=1e-8):
    """
    Compare `a` and `b` elementwise within the given tolerances.

    An entry is considered close when
    ``abs(a - b) <= atol + rtol * abs(b)``.  The union of the masks of `a`
    and `b` is used: masked positions of `a` are filled with `fill_value`
    and masked positions of `b` with 1.  Hence, if `fill_value` is True,
    masked values are considered equal.  Otherwise they are compared as
    `fill_value` against 1 and normally count as unequal.

    The relative error `rtol` should be positive and << 1.0.  The absolute
    error `atol` comes into play for those elements of `b` that are very
    small or zero; it says how small `a` must be also.

    If either input has object dtype, masked entries are replaced by their
    default fill value and the arrays are compared with ``==``; the
    tolerances are not applied on that path.

    Returns
    -------
    ndarray of bool
        The flattened elementwise result.  This is an array, not a single
        boolean: reduce it with ``np.all`` to test every component.

    """
    m = mask_or(getmask(a), getmask(b))
    d1 = filled(a)
    d2 = filled(b)
    if d1.dtype.char == "O" or d2.dtype.char == "O":
        return np.equal(d1, d2).ravel()
    # Fill x with fill_value but y with 1: masked slots then compare equal
    # only when fill_value is True (i.e. 1).
    x = filled(masked_array(d1, copy=False, mask=m),
               fill_value).astype(np.float64)
    y = filled(masked_array(d2, copy=False, mask=m), 1).astype(np.float64)
    d = np.less_equal(np.absolute(x - y), atol + rtol * np.absolute(y))
    return d.ravel()
61
62
def almost(a, b, decimal=6, fill_value=True):
    """
    Compare `a` and `b` elementwise up to `decimal` places.

    An entry is considered equal when
    ``round(abs(a - b), decimal) <= 10**(-decimal)``.  The union of the
    masks of `a` and `b` is used: masked positions of `a` are filled with
    `fill_value` and masked positions of `b` with 1.  Hence, if `fill_value`
    is True, masked values are considered equal.  Otherwise they are
    compared as `fill_value` against 1 and normally count as unequal.

    If either input has object dtype, masked entries are replaced by their
    default fill value and the arrays are compared with ``==``.

    Returns
    -------
    ndarray of bool
        The flattened elementwise result (not a single boolean).

    """
    m = mask_or(getmask(a), getmask(b))
    d1 = filled(a)
    d2 = filled(b)
    if d1.dtype.char == "O" or d2.dtype.char == "O":
        return np.equal(d1, d2).ravel()
    # Same fill convention as `approx`: fill_value for x, 1 for y.
    x = filled(masked_array(d1, copy=False, mask=m),
               fill_value).astype(np.float64)
    y = filled(masked_array(d2, copy=False, mask=m), 1).astype(np.float64)
    d = np.around(np.abs(x - y), decimal) <= 10.0 ** (-decimal)
    return d.ravel()
80
81
def _assert_equal_on_sequences(actual, desired, err_msg=''):
    """
    Asserts the equality of two non-array sequences.

    The lengths are checked first; then every item of `desired` is checked
    against the item of `actual` at the same index with `assert_equal`.
    Item failures have ``item=<index>`` prepended to `err_msg`.

    """
    assert_equal(len(actual), len(desired), err_msg)
    for idx, expected in enumerate(desired):
        item_msg = f'item={idx!r}\n{err_msg}'
        assert_equal(actual[idx], expected, item_msg)
    return None
91
92
def assert_equal_records(a, b):
    """
    Asserts that two records are equal.

    The dtypes must be equal.  Each named field of the dtype is then
    compared with `assert_equal`.  A field is skipped when its value in
    either record is the `masked` constant.  Pretty crude for now.

    """
    assert_equal(a.dtype, b.dtype)
    for f in a.dtype.names:
        af, bf = operator.getitem(a, f), operator.getitem(b, f)
        # Reuse the values fetched above instead of indexing a second time.
        if af is not masked and bf is not masked:
            assert_equal(af, bf)
    return
106
107
def assert_equal(actual, desired, err_msg=''):
    """
    Asserts that two items are equal.

    The comparison dispatches on the types of the inputs:

    1. `desired` is a dict: `actual` must be a dict of the same length
       containing every key, and the values are compared recursively.
    2. Both are lists/tuples: compared item by item.
    3. Neither is an ndarray: compared with ``==``.
    4. Otherwise (array-like): byte-string arrays are compared item by
       item, everything else with the masked-aware `assert_array_equal`.

    `err_msg` is included in the failure message and is forwarded to
    every nested comparison.

    Raises
    ------
    AssertionError
        If the items differ.
    ValueError
        If exactly one of the two items is the `masked` constant.

    """
    # Case #1: dictionary .....
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg)
        for k in desired:
            if k not in actual:
                raise AssertionError(f"{k} not in {actual}")
            assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}')
        return
    # Case #2: lists .....
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        # Forward the caller's err_msg (it used to be replaced by '').
        return _assert_equal_on_sequences(actual, desired, err_msg=err_msg)
    # Case #3: scalars and other non-ndarray objects .....
    if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
        if not desired == actual:
            raise AssertionError(build_err_msg([actual, desired], err_msg,))
        return
    # Case #4. arrays or equivalent
    if (actual is masked) != (desired is masked):
        msg = build_err_msg([actual, desired],
                            err_msg, header='', names=('x', 'y'))
        raise ValueError(msg)
    # asanyarray keeps subclasses such as MaskedArray and only copies when
    # it must (same as the former np.array(copy=False, subok=True) on
    # NumPy 1.x, without the NumPy 2 "copy=False means never copy" error).
    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    if actual.dtype.char == "S" and desired.dtype.char == "S":
        return _assert_equal_on_sequences(actual.tolist(),
                                          desired.tolist(),
                                          err_msg=err_msg)
    return assert_array_equal(actual, desired, err_msg)
145
146
def fail_if_equal(actual, desired, err_msg='',):
    """
    Raises an assertion error if two items are equal.

    Containers are compared structurally, the mirror image of
    `assert_equal`:

    * two dicts are equal when they have the same keys and every pair of
      values is equal (checked recursively);
    * two lists/tuples are equal when they have the same length and every
      pair of items is equal (checked recursively);
    * if either item is an ndarray, the check is delegated to
      `fail_if_array_equal`;
    * anything else is compared with ``!=``.

    The check therefore passes as soon as a length, a key set, or a single
    pair of items differs.

    Raises
    ------
    AssertionError
        If `actual` and `desired` are equal.

    """
    if isinstance(desired, dict) and isinstance(actual, dict):
        if actual.keys() != desired.keys():
            return
        pairs = [(actual[k], desired[k]) for k in desired]
    elif (isinstance(desired, (list, tuple))
          and isinstance(actual, (list, tuple))):
        if len(actual) != len(desired):
            return
        pairs = list(zip(actual, desired))
    elif isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
        return fail_if_array_equal(actual, desired, err_msg)
    else:
        msg = build_err_msg([actual, desired], err_msg)
        if not desired != actual:
            raise AssertionError(msg)
        return
    # Same-shaped containers: they differ as soon as one pair differs, i.e.
    # as soon as the recursive check does *not* raise.
    for a_item, d_item in pairs:
        try:
            fail_if_equal(a_item, d_item, err_msg)
        except AssertionError:
            continue  # this pair is equal; keep looking for a difference
        return
    raise AssertionError(build_err_msg([actual, desired], err_msg))


assert_not_equal = fail_if_equal
174
175
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
    """
    Asserts that two items are almost equal.

    If either item is an ndarray, the work is handed to
    `assert_array_almost_equal`.  Otherwise the items pass when
    ``round(abs(desired - actual), decimal)`` is zero, which is roughly
    ``abs(desired - actual) < 0.5 * 10**(-decimal)``.

    """
    either_is_array = (isinstance(actual, np.ndarray)
                       or isinstance(desired, np.ndarray))
    if either_is_array:
        return assert_array_almost_equal(actual, desired, decimal=decimal,
                                         err_msg=err_msg, verbose=verbose)
    failure_text = build_err_msg([actual, desired],
                                 err_msg=err_msg, verbose=verbose)
    rounded_gap = round(abs(desired - actual), decimal)
    if not (rounded_gap == 0):
        raise AssertionError(failure_text)


assert_close = assert_almost_equal
193
194
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
                         fill_value=True):
    """
    Asserts that comparison between two masked arrays is satisfied.

    Both inputs are re-wrapped as plain MaskedArrays that share the union of
    their masks.  Masked positions are then filled with `fill_value` in both
    arrays, and the elementwise `comparison` is checked on the filled data by
    `numpy.testing.assert_array_compare`.

    """
    shared_mask = mask_or(getmask(x), getmask(y))
    wrap_options = dict(copy=False, mask=shared_mask, keep_mask=False,
                        subok=False)
    x = masked_array(x, **wrap_options)
    y = masked_array(y, **wrap_options)
    # NOTE(review): x and y were just rebuilt by masked_array(), so it looks
    # like neither can be the `masked` singleton here -- confirm whether
    # this guard can ever fire.
    if (x is masked) != (y is masked):
        raise ValueError(build_err_msg([x, y], err_msg=err_msg,
                                       verbose=verbose, header=header,
                                       names=('x', 'y')))
    x_filled = x.filled(fill_value)
    y_filled = y.filled(fill_value)
    return np.testing.assert_array_compare(comparison, x_filled, y_filled,
                                           err_msg=err_msg, verbose=verbose,
                                           header=header)
218
219
def assert_array_equal(x, y, err_msg='', verbose=True):
    """
    Checks the elementwise equality of two masked arrays.

    Delegates to `assert_array_compare` with ``==``.  That helper fills
    positions masked in either input with the same value in both arrays, so
    such positions compare equal.

    """
    failure_header = 'Arrays are not equal'
    assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
                         verbose=verbose, header=failure_header)
228
229
def fail_if_array_equal(x, y, err_msg='', verbose=True):
    """
    Raises an assertion error if two masked arrays are equal elementwise.

    The arrays count as equal when every entry is close according to
    `approx` with its default tolerances.  Positions masked in either input
    are filled identically by `assert_array_compare` and therefore count as
    equal.  The check passes as soon as a single entry differs.

    """
    def compare(x, y):
        # True (check passes) only if at least one entry is NOT close.
        # np.all replaces np.alltrue, which was removed in NumPy 2.0.
        return not np.all(approx(x, y))
    # The header is shown on failure, i.e. when the arrays ARE equal.
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header='Arrays are equal')
239
240
def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Checks the equality of two masked arrays, up to a relative tolerance.

    The equality is checked elementwise with `approx`, using
    ``rtol=10**-decimal`` together with `approx`'s default ``atol``.
    `decimal` therefore acts roughly like a number of significant digits,
    not a number of decimal places (see `assert_array_almost_equal` for the
    latter).  Positions masked in either input are treated as equal.

    """
    def compare(x, y):
        "Returns the result of the loose comparison between x and y."
        return approx(x, y, rtol=10. ** -decimal)
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header='Arrays are not almost equal')
253
254
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Checks the equality of two masked arrays, up to a number of decimals.

    The check is elementwise and uses `almost`: an entry passes when
    ``round(abs(x - y), decimal) <= 10**(-decimal)``.  Positions masked in
    either input are treated as equal.

    """
    assert_array_compare(lambda lhs, rhs: almost(lhs, rhs, decimal), x, y,
                         err_msg=err_msg, verbose=verbose,
                         header='Arrays are not almost equal')
267
268
def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Checks that x is smaller than y elementwise.

    Delegates to `assert_array_compare` with the ``<`` operator.

    """
    less_than = operator.__lt__
    failure_header = 'Arrays are not less-ordered'
    assert_array_compare(less_than, x, y, err_msg=err_msg, verbose=verbose,
                         header=failure_header)
277
278
def assert_mask_equal(m1, m2, err_msg=''):
    """
    Asserts the equality of two masks.

    If either mask is `nomask`, both must be `nomask`.  In every case the
    masks are then compared elementwise with `assert_array_equal`.

    """
    if m1 is nomask or m2 is nomask:
        # One side is nomask, so the other side must be nomask as well.
        assert_(m1 is nomask and m2 is nomask)
    assert_array_equal(m1, m2, err_msg=err_msg)
289