1import numpy as np
2
3from nilearn.plotting.edge_detect import _edge_detect
4
5
def test_edge_detect():
    """Run edge detection on a two-level step image.

    The top half of a 10x10 image is 1 and the bottom half is 0. Checks
    that both returned arrays have the input's shape and that the input
    array itself is left unmodified by ``_edge_detect``.
    """
    img = np.zeros((10, 10))
    img[:5] = 1
    # Keep a copy so any in-place mutation of the input is detected,
    # not just a change to a single row.
    img_before = img.copy()
    grad_mag, edge_mask = _edge_detect(img)
    assert grad_mag.shape == img.shape
    assert edge_mask.shape == img.shape
    np.testing.assert_array_equal(img, img_before)
11
12
def test_edge_nan():
    """Edge detection must tolerate NaN values in the input image.

    Row 0 of a step image (top half 1, bottom half 0) is set to NaN.
    The gradient magnitude along that row is expected to be non-NaN and
    above 2, and the untouched row 4 of the input must still be 1.
    """
    img = np.zeros((10, 10))
    img[:5] = 1
    # ``np.nan`` rather than ``np.NaN``: the capitalised alias was
    # removed in NumPy 2.0 and raises AttributeError there.
    img[0] = np.nan
    grad_mag, _ = _edge_detect(img)
    np.testing.assert_almost_equal(img[4], 1)
    # NaN compares False against any threshold, so this also checks that
    # NaNs did not propagate into the first row of the gradient magnitude.
    assert (grad_mag[0] > 2).all()
20