1import warnings
2import pytest
3
4import numpy as np
5from numpy.lib.nanfunctions import _nan_mask, _replace_nan
6from numpy.testing import (
7    assert_, assert_equal, assert_almost_equal, assert_no_warnings,
8    assert_raises, assert_array_equal, suppress_warnings
9    )
10
11
# Test data: a 4x6 float64 array with NaNs scattered through every row.
# Rows have different numbers of finite values (4, 5, 2 and 3), which the
# ddof tests rely on to hit the "too few values" path on some rows only.
_ndat = np.array([[0.6244, np.nan, 0.2692, 0.0116, np.nan, 0.1170],
                  [0.5351, -0.9403, np.nan, 0.2100, 0.4759, 0.2833],
                  [np.nan, np.nan, np.nan, 0.1042, np.nan, -0.5954],
                  [0.1610, np.nan, np.nan, 0.1859, 0.3146, np.nan]])


# Rows of _ndat with the NaNs removed.  Kept as a list of 1-d arrays because
# the rows have different lengths; the plain (non-NaN-aware) functions are
# applied to each row to build reference values.
_rdat = [np.array([0.6244, 0.2692, 0.0116, 0.1170]),
         np.array([0.5351, -0.9403, 0.2100, 0.4759, 0.2833]),
         np.array([0.1042, -0.5954]),
         np.array([0.1610, 0.1859, 0.3146])]

# Copy of _ndat with every NaN replaced by 1.0 (the identity for products);
# np.cumprod of this is the reference for np.nancumprod(_ndat).
_ndat_ones = np.array([[0.6244, 1.0, 0.2692, 0.0116, 1.0, 0.1170],
                       [0.5351, -0.9403, 1.0, 0.2100, 0.4759, 0.2833],
                       [1.0, 1.0, 1.0, 0.1042, 1.0, -0.5954],
                       [0.1610, 1.0, 1.0, 0.1859, 0.3146, 1.0]])

# Copy of _ndat with every NaN replaced by 0.0 (the identity for sums);
# np.cumsum of this is the reference for np.nancumsum(_ndat).
_ndat_zeros = np.array([[0.6244, 0.0, 0.2692, 0.0116, 0.0, 0.1170],
                        [0.5351, -0.9403, 0.0, 0.2100, 0.4759, 0.2833],
                        [0.0, 0.0, 0.0, 0.1042, 0.0, -0.5954],
                        [0.1610, 0.0, 0.0, 0.1859, 0.3146, 0.0]])
36
37
class TestNanFunctions_MinMax:
    """Tests for np.nanmin and np.nanmax.

    Each NaN-aware function in ``nanfuncs`` is paired positionally with its
    plain counterpart in ``stdfuncs`` and compared against it.
    """

    nanfuncs = [np.nanmin, np.nanmax]
    stdfuncs = [np.min, np.max]

    def test_mutation(self):
        # Check that passed array is not modified.
        ndat = _ndat.copy()
        for f in self.nanfuncs:
            f(ndat)
            assert_equal(ndat, _ndat)

    def test_keepdims(self):
        # keepdims=True must give the same rank as the plain function.
        mat = np.eye(3)
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            for axis in [None, 0, 1]:
                tgt = rf(mat, axis=axis, keepdims=True)
                res = nf(mat, axis=axis, keepdims=True)
                assert_(res.ndim == tgt.ndim)

    def test_out(self):
        # The out= buffer is filled and returned, and matches the plain result.
        mat = np.eye(3)
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            resout = np.zeros(3)
            tgt = rf(mat, axis=1)
            res = nf(mat, axis=1, out=resout)
            assert_almost_equal(res, resout)
            assert_almost_equal(res, tgt)

    def test_dtype_from_input(self):
        # The result dtype follows the input dtype, as for the plain function.
        codes = 'efdgFDG'
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            for c in codes:
                mat = np.eye(3, dtype=c)
                tgt = rf(mat, axis=1).dtype.type
                res = nf(mat, axis=1).dtype.type
                assert_(res is tgt)
                # scalar case
                tgt = rf(mat, axis=None).dtype.type
                res = nf(mat, axis=None).dtype.type
                assert_(res is tgt)

    def test_result_values(self):
        # Reducing _ndat row-wise must equal reducing each NaN-free row.
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            tgt = [rf(d) for d in _rdat]
            res = nf(_ndat, axis=1)
            assert_almost_equal(res, tgt)

    def test_allnans(self):
        # An all-NaN slice yields NaN and exactly one RuntimeWarning.
        mat = np.array([np.nan]*9).reshape(3, 3)
        for f in self.nanfuncs:
            for axis in [None, 0, 1]:
                with warnings.catch_warnings(record=True) as w:
                    warnings.simplefilter('always')
                    assert_(np.isnan(f(mat, axis=axis)).all())
                    assert_(len(w) == 1, 'no warning raised')
                    assert_(issubclass(w[0].category, RuntimeWarning))
            # Check scalars
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                assert_(np.isnan(f(np.nan)))
                assert_(len(w) == 1, 'no warning raised')
                assert_(issubclass(w[0].category, RuntimeWarning))

    def test_masked(self):
        # Masked-array input: the result must match the plain-ndarray result,
        # the input's mask must be unchanged, and no inf values may appear in
        # the input afterwards.
        mat = np.ma.fix_invalid(_ndat)
        msk = mat._mask.copy()
        for f in [np.nanmin]:
            res = f(mat, axis=1)
            tgt = f(_ndat, axis=1)
            assert_equal(res, tgt)
            assert_equal(mat._mask, msk)
            assert_(not np.isinf(mat).any())

    def test_scalar(self):
        for f in self.nanfuncs:
            assert_(f(0.) == 0.)

    def test_subclass(self):
        class MyNDArray(np.ndarray):
            pass

        # Check that it works and that type and
        # shape are preserved
        mine = np.eye(3).view(MyNDArray)
        for f in self.nanfuncs:
            res = f(mine, axis=0)
            assert_(isinstance(res, MyNDArray))
            assert_(res.shape == (3,))
            res = f(mine, axis=1)
            assert_(isinstance(res, MyNDArray))
            assert_(res.shape == (3,))
            res = f(mine)
            assert_(res.shape == ())

        # check that rows of nan are dealt with for subclasses (#4628)
        mine[1] = np.nan
        for f in self.nanfuncs:
            # Column-wise: each column still has finite values -> no warning.
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                res = f(mine, axis=0)
                assert_(isinstance(res, MyNDArray))
                assert_(not np.any(np.isnan(res)))
                assert_(len(w) == 0)

            # Row-wise: row 1 is all NaN -> NaN there plus one warning.
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                res = f(mine, axis=1)
                assert_(isinstance(res, MyNDArray))
                assert_(np.isnan(res[1]) and not np.isnan(res[0])
                        and not np.isnan(res[2]))
                assert_(len(w) == 1, 'no warning raised')
                assert_(issubclass(w[0].category, RuntimeWarning))

            # Whole array: finite values exist -> finite result, no warning.
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                res = f(mine)
                assert_(res.shape == ())
                # ``res != np.nan`` would always be True (NaN compares
                # unequal to everything), so test for NaN explicitly.
                assert_(not np.isnan(res))
                assert_(len(w) == 0)

    def test_object_array(self):
        arr = np.array([[1.0, 2.0], [np.nan, 4.0], [np.nan, np.nan]], dtype=object)
        assert_equal(np.nanmin(arr), 1.0)
        assert_equal(np.nanmin(arr, axis=0), [1.0, 2.0])

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            # assert_equal does not work on object arrays of nan
            assert_equal(list(np.nanmin(arr, axis=1)), [1.0, 4.0, np.nan])
            assert_(len(w) == 1, 'no warning raised')
            assert_(issubclass(w[0].category, RuntimeWarning))
170
171
class TestNanFunctions_ArgminArgmax:
    """Tests for np.nanargmin and np.nanargmax."""

    nanfuncs = [np.nanargmin, np.nanargmax]

    def test_mutation(self):
        # The input array must not be modified by the call.
        scratch = _ndat.copy()
        for func in self.nanfuncs:
            func(scratch)
            assert_equal(scratch, _ndat)

    def test_result_values(self):
        # For argmin the comparison is np.greater (nothing may be smaller than
        # the pick); for argmax it is np.less.
        for func, violates in zip(self.nanfuncs, [np.greater, np.less]):
            for row in _ndat:
                with suppress_warnings() as sup:
                    sup.filter(RuntimeWarning, "invalid value encountered in")
                    idx = func(row)
                    picked = row[idx]
                    # NaN compares False with everything (except !=), so the
                    # checks below would pass vacuously for a NaN pick.
                    assert_(not np.isnan(picked))
                    assert_(not violates(picked, row).any())
                    # The pick must be the first occurrence of its value.
                    assert_(not np.equal(picked, row[:idx]).any())

    def test_allnans(self):
        # An all-NaN slice has no valid index, so ValueError is raised.
        all_nan = np.full((3, 3), np.nan)
        for func in self.nanfuncs:
            for ax in (None, 0, 1):
                assert_raises(ValueError, func, all_nan, axis=ax)
            assert_raises(ValueError, func, np.nan)

    def test_empty(self):
        empty = np.zeros((0, 3))
        for func in self.nanfuncs:
            # Reducing over the zero-length axis (or everything) is an error.
            for ax in (0, None):
                assert_raises(ValueError, func, empty, axis=ax)
            # Reducing over the length-3 axis yields an empty index array.
            assert_equal(func(empty, axis=1), np.zeros(0))

    def test_scalar(self):
        # A 0-d input has its only element at index 0.
        for func in self.nanfuncs:
            assert_(func(0.) == 0.)

    def test_subclass(self):
        class MyNDArray(np.ndarray):
            pass

        # Type and shape must survive the reduction for an ndarray subclass.
        mine = np.eye(3).view(MyNDArray)
        for func in self.nanfuncs:
            for ax in (0, 1):
                out = func(mine, axis=ax)
                assert_(isinstance(out, MyNDArray))
                assert_(out.shape == (3,))
            assert_(func(mine).shape == ())
232
233
class TestNanFunctions_IntTypes:
    """Nan-functions must agree with the plain functions on integer input.

    Integer arrays cannot hold NaN, so for every integer dtype in
    ``int_types`` each nan-function is compared with its plain counterpart
    evaluated on the canonical ``mat``.
    """

    int_types = (np.int8, np.int16, np.int32, np.int64, np.uint8,
                 np.uint16, np.uint32, np.uint64)

    # All values fit in int8/uint8, so the astype() casts below are lossless.
    mat = np.array([127, 39, 93, 87, 46])

    def integer_arrays(self):
        """Yield ``self.mat`` cast to each dtype in ``int_types``."""
        for dtype in self.int_types:
            yield self.mat.astype(dtype)

    def test_nanmin(self):
        tgt = np.min(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nanmin(mat), tgt)

    def test_nanmax(self):
        tgt = np.max(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nanmax(mat), tgt)

    def test_nanargmin(self):
        tgt = np.argmin(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nanargmin(mat), tgt)

    def test_nanargmax(self):
        tgt = np.argmax(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nanargmax(mat), tgt)

    def test_nansum(self):
        tgt = np.sum(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nansum(mat), tgt)

    def test_nanprod(self):
        tgt = np.prod(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nanprod(mat), tgt)

    def test_nancumsum(self):
        tgt = np.cumsum(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nancumsum(mat), tgt)

    def test_nancumprod(self):
        tgt = np.cumprod(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nancumprod(mat), tgt)

    def test_nanmean(self):
        tgt = np.mean(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nanmean(mat), tgt)

    def test_nanvar(self):
        tgt = np.var(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nanvar(mat), tgt)

        # The reference comes from the canonical ``self.mat`` (as in
        # test_nanstd), not from whichever cast array the loop left behind.
        tgt = np.var(self.mat, ddof=1)
        for mat in self.integer_arrays():
            assert_equal(np.nanvar(mat, ddof=1), tgt)

    def test_nanstd(self):
        tgt = np.std(self.mat)
        for mat in self.integer_arrays():
            assert_equal(np.nanstd(mat), tgt)

        tgt = np.std(self.mat, ddof=1)
        for mat in self.integer_arrays():
            assert_equal(np.nanstd(mat, ddof=1), tgt)
307
308
class SharedNanFunctionsTestsMixin:
    """Checks shared by every nan-function / plain-function pairing.

    Subclasses supply two parallel lists: ``nanfuncs`` (the NaN-aware
    functions under test) and ``stdfuncs`` (their plain counterparts).
    """

    def test_mutation(self):
        # The input array must come back untouched.
        scratch = _ndat.copy()
        for func in self.nanfuncs:
            func(scratch)
            assert_equal(scratch, _ndat)

    def test_keepdims(self):
        eye = np.eye(3)
        for nan_f, std_f in zip(self.nanfuncs, self.stdfuncs):
            for ax in (None, 0, 1):
                expected = std_f(eye, axis=ax, keepdims=True)
                got = nan_f(eye, axis=ax, keepdims=True)
                assert_(got.ndim == expected.ndim)

    def test_out(self):
        eye = np.eye(3)
        for nan_f, std_f in zip(self.nanfuncs, self.stdfuncs):
            buf = np.zeros(3)
            expected = std_f(eye, axis=1)
            got = nan_f(eye, axis=1, out=buf)
            # The returned value and the out buffer must agree, and both must
            # match the plain function.
            assert_almost_equal(got, buf)
            assert_almost_equal(got, expected)

    def _check_dtype_argument(self, as_dtype_arg):
        # Shared body of the dtype= tests: ``as_dtype_arg`` turns a type code
        # into the value passed as ``dtype=``; the result type must match the
        # plain function's for both an axis reduction and a full reduction.
        eye = np.eye(3)
        for nan_f, std_f in zip(self.nanfuncs, self.stdfuncs):
            for code in 'efdgFDG':
                dt = as_dtype_arg(code)
                with suppress_warnings() as sup:
                    if nan_f in {np.nanstd, np.nanvar} and code in 'FDG':
                        # Giving the warning is a small bug, see gh-8000
                        sup.filter(np.ComplexWarning)
                    for ax in (1, None):
                        expected = std_f(eye, dtype=dt, axis=ax).dtype.type
                        got = nan_f(eye, dtype=dt, axis=ax).dtype.type
                        assert_(got is expected)

    def test_dtype_from_dtype(self):
        self._check_dtype_argument(np.dtype)

    def test_dtype_from_char(self):
        self._check_dtype_argument(lambda code: code)

    def test_dtype_from_input(self):
        # Without dtype=, the result type follows the input's dtype.
        for nan_f, std_f in zip(self.nanfuncs, self.stdfuncs):
            for code in 'efdgFDG':
                eye = np.eye(3, dtype=code)
                expected = std_f(eye, axis=1).dtype.type
                got = nan_f(eye, axis=1).dtype.type
                assert_(got is expected, "res %s, tgt %s" % (got, expected))
                # scalar case
                expected = std_f(eye, axis=None).dtype.type
                got = nan_f(eye, axis=None).dtype.type
                assert_(got is expected)

    def test_result_values(self):
        # Row-wise reduction of _ndat equals the plain reduction of each
        # NaN-free row.
        for nan_f, std_f in zip(self.nanfuncs, self.stdfuncs):
            expected = [std_f(row) for row in _rdat]
            assert_almost_equal(nan_f(_ndat, axis=1), expected)

    def test_scalar(self):
        for func in self.nanfuncs:
            assert_(func(0.) == 0.)

    def test_subclass(self):
        class MyNDArray(np.ndarray):
            pass

        # The subclass type must be preserved and the shape must match the
        # result computed on the plain ndarray.
        base = np.eye(3)
        mine = base.view(MyNDArray)
        for func in self.nanfuncs:
            for kwargs in ({'axis': 0}, {'axis': 1}, {}):
                expected_shape = func(base, **kwargs).shape
                got = func(mine, **kwargs)
                assert_(isinstance(got, MyNDArray))
                assert_(got.shape == expected_shape)
412
413
class TestNanFunctions_SumProd(SharedNanFunctionsTestsMixin):
    """nansum/nanprod checks on top of the shared mixin tests."""

    nanfuncs = [np.nansum, np.nanprod]
    stdfuncs = [np.sum, np.prod]

    def test_allnans(self):
        # Unlike most nan-functions, nansum and nanprod do not warn on
        # all-NaN input; they return the identity of the operation
        # (0 for sum, 1 for prod).
        for f, tgt_value in zip(self.nanfuncs, [0, 1]):
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                res = f([np.nan]*3, axis=None)
                assert_(res == tgt_value, 'result is not %s' % tgt_value)
                assert_(len(w) == 0, 'warning raised')
                # Check scalar
                res = f(np.nan)
                assert_(res == tgt_value, 'result is not %s' % tgt_value)
                assert_(len(w) == 0, 'warning raised')
                # Check there is no warning for not all-nan
                f([0]*3, axis=None)
                assert_(len(w) == 0, 'unwanted warning raised')

    def test_empty(self):
        # Reducing over the empty axis (or everything) gives the identity;
        # reducing over the length-3 axis of a (0, 3) array gives [].
        for f, tgt_value in zip(self.nanfuncs, [0, 1]):
            mat = np.zeros((0, 3))
            tgt = [tgt_value]*3
            res = f(mat, axis=0)
            assert_equal(res, tgt)
            tgt = []
            res = f(mat, axis=1)
            assert_equal(res, tgt)
            tgt = tgt_value
            res = f(mat, axis=None)
            assert_equal(res, tgt)
446
447
class TestNanFunctions_CumSumProd(SharedNanFunctionsTestsMixin):
    """nancumsum/nancumprod checks on top of the shared mixin tests.

    NaNs are treated as the identity of the operation (0 for sums, 1 for
    products), so the results are compared with the plain cumulative
    functions applied to copies of the data with NaNs replaced accordingly.
    """

    nanfuncs = [np.nancumsum, np.nancumprod]
    stdfuncs = [np.cumsum, np.cumprod]

    def test_allnans(self):
        # Unlike other nan-functions, sum/prod/cumsum/cumprod don't warn on all nan input
        for func, identity in zip(self.nanfuncs, [0, 1]):
            with assert_no_warnings():
                got = func([np.nan]*3, axis=None)
                assert_(np.array_equal(got, identity * np.ones((3))),
                        'result is not %s * np.ones((3))' % (identity))
                # Check scalar
                got = func(np.nan)
                assert_(np.array_equal(got, identity * np.ones((1))),
                        'result is not %s * np.ones((1))' % (identity))
                # Check there is no warning for not all-nan
                func([0]*3, axis=None)

    def test_empty(self):
        for func, identity in zip(self.nanfuncs, [0, 1]):
            empty = np.zeros((0, 3))
            # (axis, expected) pairs: an axis keeps the input shape, while
            # axis=None flattens to 1-d.
            cases = [
                (0, identity * np.ones((0, 3))),
                (1, empty),
                (None, np.zeros((0))),
            ]
            for ax, expected in cases:
                assert_equal(func(empty, axis=ax), expected)

    def test_keepdims(self):
        # Cumulative functions have no keepdims; instead check the output
        # rank matches the plain function's for each axis choice.
        for nan_f, std_f in zip(self.nanfuncs, self.stdfuncs):
            eye = np.eye(3)
            for ax in (None, 0, 1):
                nan_res = nan_f(eye, axis=ax, out=None)
                std_res = std_f(eye, axis=ax, out=None)
                assert_(std_res.ndim == nan_res.ndim)

        for func in self.nanfuncs:
            data = np.ones((3, 5, 7, 11))
            # Randomly set some elements to NaN:
            rng = np.random.RandomState(0)
            data[rng.rand(*data.shape) < 0.5] = np.nan
            # axis=None flattens: 3*5*7*11 == 1155 elements.
            assert_equal(func(data, axis=None).shape, (1155,))
            for ax in np.arange(4):
                assert_equal(func(data, axis=ax).shape, (3, 5, 7, 11))

    def test_result_values(self):
        for ax in (-2, -1, 0, 1, None):
            assert_almost_equal(np.nancumprod(_ndat, axis=ax),
                                np.cumprod(_ndat_ones, axis=ax))
            assert_almost_equal(np.nancumsum(_ndat, axis=ax),
                                np.cumsum(_ndat_zeros, axis=ax))

    def test_out(self):
        eye = np.eye(3)
        for nan_f, std_f in zip(self.nanfuncs, self.stdfuncs):
            buf = np.eye(3)
            for ax in (-2, -1, 0, 1):
                expected = std_f(eye, axis=ax)
                got = nan_f(eye, axis=ax, out=buf)
                assert_almost_equal(got, buf)
                assert_almost_equal(got, expected)
517
518
class TestNanFunctions_MeanVarStd(SharedNanFunctionsTestsMixin):
    """nanmean/nanvar/nanstd checks on top of the shared mixin tests."""

    nanfuncs = [np.nanmean, np.nanvar, np.nanstd]
    stdfuncs = [np.mean, np.var, np.std]

    def test_dtype_error(self):
        # Boolean, integer and object computation dtypes are rejected.
        for f in self.nanfuncs:
            for dtype in [np.bool_, np.int_, np.object_]:
                assert_raises(TypeError, f, _ndat, axis=1, dtype=dtype)

    def test_out_dtype_error(self):
        # Same restriction when the dtype comes from the out= buffer.
        for f in self.nanfuncs:
            for dtype in [np.bool_, np.int_, np.object_]:
                out = np.empty(_ndat.shape[0], dtype=dtype)
                assert_raises(TypeError, f, _ndat, axis=1, out=out)

    def test_ddof(self):
        nanfuncs = [np.nanvar, np.nanstd]
        stdfuncs = [np.var, np.std]
        for nf, rf in zip(nanfuncs, stdfuncs):
            for ddof in [0, 1]:
                tgt = [rf(d, ddof=ddof) for d in _rdat]
                res = nf(_ndat, axis=1, ddof=ddof)
                assert_almost_equal(res, tgt)

    def test_ddof_too_big(self):
        # Rows whose count of finite values is <= ddof give NaN, and a single
        # RuntimeWarning is emitted whenever at least one such row exists.
        nanfuncs = [np.nanvar, np.nanstd]
        dsize = [len(d) for d in _rdat]
        for nf in nanfuncs:
            for ddof in range(5):
                with suppress_warnings() as sup:
                    sup.record(RuntimeWarning)
                    sup.filter(np.ComplexWarning)
                    tgt = [ddof >= d for d in dsize]
                    res = nf(_ndat, axis=1, ddof=ddof)
                    assert_equal(np.isnan(res), tgt)
                    if any(tgt):
                        assert_(len(sup.log) == 1)
                    else:
                        assert_(len(sup.log) == 0)

    def test_allnans(self):
        # All-NaN input yields NaN with one RuntimeWarning per call.
        mat = np.array([np.nan]*9).reshape(3, 3)
        for f in self.nanfuncs:
            for axis in [None, 0, 1]:
                with warnings.catch_warnings(record=True) as w:
                    warnings.simplefilter('always')
                    assert_(np.isnan(f(mat, axis=axis)).all())
                    assert_(len(w) == 1)
                    assert_(issubclass(w[0].category, RuntimeWarning))
                    # Check scalar
                    assert_(np.isnan(f(np.nan)))
                    assert_(len(w) == 2)
                    # Check the warning from the scalar call, not w[0] again.
                    assert_(issubclass(w[1].category, RuntimeWarning))

    def test_empty(self):
        mat = np.zeros((0, 3))
        for f in self.nanfuncs:
            # Reducing over the empty axis (or everything): NaN plus a warning.
            for axis in [0, None]:
                with warnings.catch_warnings(record=True) as w:
                    warnings.simplefilter('always')
                    assert_(np.isnan(f(mat, axis=axis)).all())
                    assert_(len(w) == 1)
                    assert_(issubclass(w[0].category, RuntimeWarning))
            # Reducing over the length-3 axis: an empty (0,) result and no
            # warning.  Compare against a (0,)-shaped array; a 0-d target
            # would broadcast and let any shape pass.
            for axis in [1]:
                with warnings.catch_warnings(record=True) as w:
                    warnings.simplefilter('always')
                    assert_equal(f(mat, axis=axis), np.zeros(0))
                    assert_(len(w) == 0)
589
590
class TestNanFunctions_Median:
    """Tests for np.nanmedian, mostly checked against np.median applied to
    the NaN-free data."""

    def test_mutation(self):
        # Check that passed array is not modified.
        ndat = _ndat.copy()
        np.nanmedian(ndat)
        assert_equal(ndat, _ndat)

    def test_keepdims(self):
        mat = np.eye(3)
        for axis in [None, 0, 1]:
            tgt = np.median(mat, axis=axis, out=None, overwrite_input=False)
            res = np.nanmedian(mat, axis=axis, out=None, overwrite_input=False)
            assert_(res.ndim == tgt.ndim)

        d = np.ones((3, 5, 7, 11))
        # Randomly set some elements to NaN:
        w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
        w = w.astype(np.intp)
        d[tuple(w)] = np.nan
        # keepdims=True keeps every reduced axis as length 1.
        with suppress_warnings() as sup:
            sup.filter(RuntimeWarning)
            res = np.nanmedian(d, axis=None, keepdims=True)
            assert_equal(res.shape, (1, 1, 1, 1))
            res = np.nanmedian(d, axis=(0, 1), keepdims=True)
            assert_equal(res.shape, (1, 1, 7, 11))
            res = np.nanmedian(d, axis=(0, 3), keepdims=True)
            assert_equal(res.shape, (1, 5, 7, 1))
            res = np.nanmedian(d, axis=(1,), keepdims=True)
            assert_equal(res.shape, (3, 1, 7, 11))
            res = np.nanmedian(d, axis=(0, 1, 2, 3), keepdims=True)
            assert_equal(res.shape, (1, 1, 1, 1))
            res = np.nanmedian(d, axis=(0, 1, 3), keepdims=True)
            assert_equal(res.shape, (1, 1, 7, 1))

    def test_out(self):
        mat = np.random.rand(3, 3)
        # NaN columns are inserted, so nanmedian must ignore them.
        nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)
        resout = np.zeros(3)
        tgt = np.median(mat, axis=1)
        res = np.nanmedian(nan_mat, axis=1, out=resout)
        assert_almost_equal(res, resout)
        assert_almost_equal(res, tgt)
        # 0-d output:
        resout = np.zeros(())
        tgt = np.median(mat, axis=None)
        res = np.nanmedian(nan_mat, axis=None, out=resout)
        assert_almost_equal(res, resout)
        assert_almost_equal(res, tgt)
        res = np.nanmedian(nan_mat, axis=(0, 1), out=resout)
        assert_almost_equal(res, resout)
        assert_almost_equal(res, tgt)

    def test_small_large(self):
        # test the small and large code paths, current cutoff 400 elements
        for s in [5, 20, 51, 200, 1000]:
            d = np.random.randn(4, s)
            # Randomly set some elements to NaN:
            w = np.random.randint(0, d.size, size=d.size // 5)
            d.ravel()[w] = np.nan
            d[:, 0] = 1.  # ensure at least one good value
            # use normal median without nans to compare
            tgt = []
            for x in d:
                nonan = np.compress(~np.isnan(x), x)
                tgt.append(np.median(nonan, overwrite_input=True))

            assert_array_equal(np.nanmedian(d, axis=-1), tgt)

    def test_result_values(self):
        tgt = [np.median(d) for d in _rdat]
        res = np.nanmedian(_ndat, axis=1)
        assert_almost_equal(res, tgt)

    def test_allnans(self):
        # Each all-NaN slice yields NaN; the recorded warning count below is
        # 1 for axis=None and 3 for a per-axis reduction of the 3x3 input.
        mat = np.array([np.nan]*9).reshape(3, 3)
        for axis in [None, 0, 1]:
            with suppress_warnings() as sup:
                sup.record(RuntimeWarning)

                assert_(np.isnan(np.nanmedian(mat, axis=axis)).all())
                if axis is None:
                    assert_(len(sup.log) == 1)
                else:
                    assert_(len(sup.log) == 3)
                # Check scalar
                assert_(np.isnan(np.nanmedian(np.nan)))
                if axis is None:
                    assert_(len(sup.log) == 2)
                else:
                    assert_(len(sup.log) == 4)

    def test_empty(self):
        mat = np.zeros((0, 3))
        # Reducing over the empty axis (or everything): NaN plus a warning.
        for axis in [0, None]:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                assert_(np.isnan(np.nanmedian(mat, axis=axis)).all())
                assert_(len(w) == 1)
                assert_(issubclass(w[0].category, RuntimeWarning))
        # Reducing over the length-3 axis: an empty (0,) result and no
        # warning.  Compare against a (0,)-shaped array; a 0-d target would
        # broadcast and let any shape pass.
        for axis in [1]:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                assert_equal(np.nanmedian(mat, axis=axis), np.zeros(0))
                assert_(len(w) == 0)

    def test_scalar(self):
        assert_(np.nanmedian(0.) == 0.)

    def test_extended_axis_invalid(self):
        # Out-of-range axes raise AxisError; a repeated axis raises ValueError.
        d = np.ones((3, 5, 7, 11))
        assert_raises(np.AxisError, np.nanmedian, d, axis=-5)
        assert_raises(np.AxisError, np.nanmedian, d, axis=(0, -5))
        assert_raises(np.AxisError, np.nanmedian, d, axis=4)
        assert_raises(np.AxisError, np.nanmedian, d, axis=(0, 4))
        assert_raises(ValueError, np.nanmedian, d, axis=(1, 1))

    def test_float_special(self):
        # Mixes of +/-inf and NaN, run once with inf = +inf and once with
        # inf = -inf.
        with suppress_warnings() as sup:
            sup.filter(RuntimeWarning)
            for inf in [np.inf, -np.inf]:
                a = np.array([[inf,  np.nan], [np.nan, np.nan]])
                assert_equal(np.nanmedian(a, axis=0), [inf,  np.nan])
                assert_equal(np.nanmedian(a, axis=1), [inf,  np.nan])
                assert_equal(np.nanmedian(a), inf)

                # minimum fill value check
                a = np.array([[np.nan, np.nan, inf],
                              [np.nan, np.nan, inf]])
                assert_equal(np.nanmedian(a), inf)
                assert_equal(np.nanmedian(a, axis=0), [np.nan, np.nan, inf])
                assert_equal(np.nanmedian(a, axis=1), inf)

                # no mask path
                a = np.array([[inf, inf], [inf, inf]])
                assert_equal(np.nanmedian(a, axis=1), inf)

                a = np.array([[inf, 7, -inf, -9],
                              [-10, np.nan, np.nan, 5],
                              [4, np.nan, np.nan, inf]],
                             dtype=np.float32)
                if inf > 0:
                    assert_equal(np.nanmedian(a, axis=0), [4., 7., -inf, 5.])
                    assert_equal(np.nanmedian(a), 4.5)
                else:
                    assert_equal(np.nanmedian(a, axis=0), [-10., 7., -inf, -9.])
                    assert_equal(np.nanmedian(a), -2.5)
                assert_equal(np.nanmedian(a, axis=-1), [-1., -2.5, inf])

                for i in range(0, 10):
                    for j in range(1, 10):
                        a = np.array([([np.nan] * i) + ([inf] * j)] * 2)
                        assert_equal(np.nanmedian(a), inf)
                        assert_equal(np.nanmedian(a, axis=1), inf)
                        assert_equal(np.nanmedian(a, axis=0),
                                     ([np.nan] * i) + [inf] * j)

                        a = np.array([([np.nan] * i) + ([-inf] * j)] * 2)
                        assert_equal(np.nanmedian(a), -inf)
                        assert_equal(np.nanmedian(a, axis=1), -inf)
                        assert_equal(np.nanmedian(a, axis=0),
                                     ([np.nan] * i) + [-inf] * j)
753
754
class TestNanFunctions_Percentile:
    """Tests for ``np.nanpercentile``.

    Where possible, results are compared with ``np.percentile`` applied to
    the same data with the NaNs removed, or to identical NaN-free input.
    """

    def test_mutation(self):
        # Check that passed array is not modified.
        ndat = _ndat.copy()
        np.nanpercentile(ndat, 30)
        assert_equal(ndat, _ndat)

    def test_keepdims(self):
        # On NaN-free input nanpercentile must give the same result shape as
        # np.percentile, both with and without keepdims.
        mat = np.eye(3)
        for axis in [None, 0, 1]:
            for keepdims in [False, True]:
                tgt = np.percentile(mat, 70, axis=axis, out=None,
                                    overwrite_input=False, keepdims=keepdims)
                res = np.nanpercentile(mat, 70, axis=axis, out=None,
                                       overwrite_input=False,
                                       keepdims=keepdims)
                assert_equal(np.shape(res), np.shape(tgt))

        d = np.ones((3, 5, 7, 11))
        # Randomly set some elements to NaN:
        w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
        w = w.astype(np.intp)
        d[tuple(w)] = np.nan
        with suppress_warnings() as sup:
            sup.filter(RuntimeWarning)
            res = np.nanpercentile(d, 90, axis=None, keepdims=True)
            assert_equal(res.shape, (1, 1, 1, 1))
            res = np.nanpercentile(d, 90, axis=(0, 1), keepdims=True)
            assert_equal(res.shape, (1, 1, 7, 11))
            res = np.nanpercentile(d, 90, axis=(0, 3), keepdims=True)
            assert_equal(res.shape, (1, 5, 7, 1))
            res = np.nanpercentile(d, 90, axis=(1,), keepdims=True)
            assert_equal(res.shape, (3, 1, 7, 11))
            res = np.nanpercentile(d, 90, axis=(0, 1, 2, 3), keepdims=True)
            assert_equal(res.shape, (1, 1, 1, 1))
            res = np.nanpercentile(d, 90, axis=(0, 1, 3), keepdims=True)
            assert_equal(res.shape, (1, 1, 7, 1))

    def test_out(self):
        # nan_mat is mat with NaN columns inserted, so ignoring the NaNs must
        # reproduce np.percentile(mat); the result must also land in `out`.
        mat = np.random.rand(3, 3)
        nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)
        resout = np.zeros(3)
        tgt = np.percentile(mat, 42, axis=1)
        res = np.nanpercentile(nan_mat, 42, axis=1, out=resout)
        assert_almost_equal(res, resout)
        assert_almost_equal(res, tgt)
        # 0-d output:
        resout = np.zeros(())
        tgt = np.percentile(mat, 42, axis=None)
        res = np.nanpercentile(nan_mat, 42, axis=None, out=resout)
        assert_almost_equal(res, resout)
        assert_almost_equal(res, tgt)
        res = np.nanpercentile(nan_mat, 42, axis=(0, 1), out=resout)
        assert_almost_equal(res, resout)
        assert_almost_equal(res, tgt)

    def test_result_values(self):
        tgt = [np.percentile(d, 28) for d in _rdat]
        res = np.nanpercentile(_ndat, 28, axis=1)
        assert_almost_equal(res, tgt)
        # Transpose the array to fit the output convention of numpy.percentile
        tgt = np.transpose([np.percentile(d, (28, 98)) for d in _rdat])
        res = np.nanpercentile(_ndat, (28, 98), axis=1)
        assert_almost_equal(res, tgt)

    def test_allnans(self):
        # All-NaN input yields NaN and records one RuntimeWarning per
        # all-NaN slice reduced (a single one when axis is None).
        mat = np.array([np.nan]*9).reshape(3, 3)
        for axis in [None, 0, 1]:
            n_slices = 1 if axis is None else 3
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                assert_(np.isnan(np.nanpercentile(mat, 60, axis=axis)).all())
                assert_equal(len(w), n_slices)
                assert_(issubclass(w[0].category, RuntimeWarning))
                # Check scalar
                assert_(np.isnan(np.nanpercentile(np.nan, 60)))
                assert_equal(len(w), n_slices + 1)
                # Inspect the warning emitted by the scalar call, i.e. the
                # most recently recorded one (w[0] was checked above).
                assert_(issubclass(w[-1].category, RuntimeWarning))

    def test_empty(self):
        mat = np.zeros((0, 3))
        # Reducing over the empty axis (or over everything) gives NaN and a
        # single RuntimeWarning.
        for axis in [0, None]:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                assert_(np.isnan(np.nanpercentile(mat, 40, axis=axis)).all())
                assert_(len(w) == 1)
                assert_(issubclass(w[0].category, RuntimeWarning))
        # Reducing over the non-empty axis yields an empty result and no
        # warnings.  Pin the shape explicitly: a 0-d target such as
        # np.zeros([]) broadcasts against any shape and would not check it.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            res = np.nanpercentile(mat, 40, axis=1)
            assert_equal(res.shape, (0,))
            assert_equal(res, np.zeros(0))
            assert_(len(w) == 0)

    def test_scalar(self):
        # 0-d input and a full 1-d reduction both return a scalar.
        assert_equal(np.nanpercentile(0., 100), 0.)
        a = np.arange(6)
        r = np.nanpercentile(a, 50, axis=0)
        assert_equal(r, 2.5)
        assert_(np.isscalar(r))

    def test_extended_axis_invalid(self):
        # Out-of-range axes raise AxisError; repeated axes raise ValueError.
        d = np.ones((3, 5, 7, 11))
        assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=-5)
        assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=(0, -5))
        assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=4)
        assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=(0, 4))
        assert_raises(ValueError, np.nanpercentile, d, q=5, axis=(1, 1))

    def test_multiple_percentiles(self):
        # A sequence of percentiles must give the same shape (and, for
        # NaN-free data, the same values) as np.percentile.
        perc = [50, 100]
        mat = np.ones((4, 3))
        nan_mat = np.nan * mat
        # For checking consistency in higher dimensional case
        large_mat = np.ones((3, 4, 5))
        # NOTE(review): 0:2:4 selects only index 0 along axis 1; confirm
        # whether 0:4:2 was intended.
        large_mat[:, 0:2:4, :] = 0
        large_mat[:, :, 3:] *= 2
        for axis in [None, 0, 1]:
            for keepdim in [False, True]:
                with suppress_warnings() as sup:
                    sup.filter(RuntimeWarning, "All-NaN slice encountered")
                    val = np.percentile(mat, perc, axis=axis, keepdims=keepdim)
                    nan_val = np.nanpercentile(nan_mat, perc, axis=axis,
                                               keepdims=keepdim)
                    assert_equal(nan_val.shape, val.shape)

                    val = np.percentile(large_mat, perc, axis=axis,
                                        keepdims=keepdim)
                    nan_val = np.nanpercentile(large_mat, perc, axis=axis,
                                               keepdims=keepdim)
                    assert_equal(nan_val, val)

        megamat = np.ones((3, 4, 5, 6))
        assert_equal(np.nanpercentile(megamat, perc, axis=(1, 2)).shape, (2, 3, 6))
892
893
class TestNanFunctions_Quantile:
    """Tests for ``np.nanquantile``.

    Most of the shared quantile/percentile machinery is already exercised by
    the percentile tests, so only a few quantile-facing checks live here.
    """

    def test_regression(self):
        # nanquantile at q must match nanpercentile at 100*q exactly, for
        # scalar and sequence q and for several axes.
        data = np.arange(24, dtype=float).reshape(2, 3, 4)
        data[0, 1] = np.nan

        cases = [
            (0.5, 50, None),
            (0.5, 50, 0),
            (0.5, 50, 1),
            ([0.5], [50], 1),
            ([0.25, 0.5, 0.75], [25, 50, 75], 1),
        ]
        for fraction, percent, axis in cases:
            assert_equal(np.nanquantile(data, q=fraction, axis=axis),
                         np.nanpercentile(data, q=percent, axis=axis))

    def test_basic(self):
        # Endpoints and midpoint of an evenly spaced sample.
        samples = 0.5 * np.arange(8)
        for q, expected in ((0, 0.), (1, 3.5), (0.5, 1.75)):
            assert_equal(np.nanquantile(samples, q), expected)

    def test_no_p_overwrite(self):
        # Retested here because quantile does not make a copy of ``q``:
        # neither an ndarray nor a list of probabilities may be modified.
        reference = np.array([0, 0.75, 0.25, 0.5, 1.0])
        for probs in (reference.copy(), reference.tolist()):
            np.nanquantile(np.arange(100.), probs, interpolation="midpoint")
            assert_array_equal(probs, reference)
928
@pytest.mark.parametrize("arr, expected", [
    # array of floats with some nans
    (np.array([np.nan, 5.0, np.nan, np.inf]),
     np.array([False, True, False, True])),
    # int64 array that can't possibly have nans
    (np.array([1, 5, 7, 9], dtype=np.int64),
     True),
    # bool array that can't possibly have nans
    (np.array([False, True, False, True]),
     True),
    # 2-D complex array with nans
    (np.array([[np.nan, 5.0],
               [np.nan, np.inf]], dtype=np.complex64),
     np.array([[False, True],
               [False, True]])),
    ])
def test__nan_mask(arr, expected):
    """``_nan_mask`` flags non-NaN entries as True, with or without ``out``.

    For dtypes that cannot hold NaN, the result must be the singleton
    ``True`` itself rather than an array filled with True.
    """
    preallocated = np.empty(arr.shape, dtype=np.bool_)
    for buffer in (None, preallocated):
        got = _nan_mask(arr, out=buffer)
        assert_equal(got, expected)
        # assert_equal cannot tell ``True`` apart from an all-True array, so
        # when a scalar is expected, check identity explicitly.
        if type(expected) is np.ndarray:
            continue
        assert got is True
954
955
def test__replace_nan():
    """Check the copying and NaN-replacement behaviour of ``_replace_nan``.

    * For bool and integer dtypes no mask is produced (``None``) and the
      input array itself is returned, not a copy.
    * For real and complex floating dtypes a boolean NaN mask is always
      returned, so a copy is made even when the data has no NaNs.  NaNs in
      the copy are replaced by ``val`` and the caller's array is untouched.
    """
    for dtype in [np.bool_, np.int32, np.int64]:
        arr = np.array([0, 1], dtype=dtype)
        result, mask = _replace_nan(arr, 0)
        assert mask is None
        # do not make a copy if no mask is needed
        assert result is arr

    for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
        arr = np.array([0, 1], dtype=dtype)
        result, mask = _replace_nan(arr, 2)
        assert not mask.any()
        # mask is not None, so we make a copy
        assert result is not arr
        assert_equal(result, arr)

        arr_nan = np.array([0, 1, np.nan], dtype=dtype)
        result_nan, mask_nan = _replace_nan(arr_nan, 2)
        assert_equal(mask_nan, np.array([False, False, True]))
        assert result_nan is not arr_nan
        assert_equal(result_nan, np.array([0, 1, 2]))
        # the replacement happened on the copy; the input keeps its NaN
        assert np.isnan(arr_nan[-1])
981