1import warnings
2
3import numpy as np
4from scipy.ndimage import generate_binary_structure, binary_dilation
5from scipy.ndimage.filters import median_filter
6
7from dipy.segment.mask import (otsu, bounding_box, crop, applymask,
8                               multi_median, median_otsu)
9
10from numpy.testing import (assert_equal,
11                           assert_almost_equal,
12                           assert_raises,
13                           run_module_suite)
14from dipy.data import get_fnames
15from dipy.io.image import load_nifti_data
16
17
def test_mask():
    """Exercise otsu, bounding_box, crop, applymask and multi_median.

    A single voxel dilated 4 times with a connectivity-1 structuring element
    gives a binary test volume. Its foreground voxel count (``initial``) must
    be preserved by thresholding, cropping and masking.
    """
    vol = np.zeros((30, 30, 30))
    vol[15, 15, 15] = 1
    struct = generate_binary_structure(3, 1)
    voln = binary_dilation(vol, structure=struct, iterations=4).astype('f4')
    initial = np.sum(voln > 0)

    # Otsu's threshold on this two-valued volume should keep exactly the
    # dilated foreground.
    thresh = otsu(voln)
    mask = voln > thresh
    assert_equal(np.sum(mask), initial)

    # Cropping the mask to its bounding box must keep every foreground voxel.
    mins, maxs = bounding_box(mask)
    mask_crop = crop(mask, mins, maxs)
    assert_equal(np.sum(mask_crop > 0), initial)

    # applymask returns the masked volume; the result must be captured,
    # otherwise the check below would only re-inspect the unmasked input.
    masked = applymask(voln, mask)
    assert_equal(np.sum(masked > 0), initial)
    # Masking an all-ones volume isolates the effect of the mask: exactly the
    # foreground voxels survive.
    masked_ones = applymask(np.ones_like(voln), mask)
    assert_equal(np.sum(masked_ones > 0), initial)
    # A 3D mask applied to a 4D volume is broadcast over the trailing axis.
    masked_4d = applymask(np.ones(voln.shape + (2,)), mask)
    assert_equal(np.sum(masked_4d > 0), 2 * initial)
    # The input volume itself is left untouched.
    assert_equal(np.sum(voln > 0), initial)

    # multi_median must not modify its input and must match repeated
    # application of a median filter with a (2 * radius + 1) window.
    img = np.arange(25).reshape(5, 5)
    img_copy = img.copy()
    medianradius = 2
    numpass = 3
    median_test = multi_median(img, medianradius, numpass)
    assert_equal(img, img_copy)

    medarr = np.ones_like(img.shape) * ((medianradius * 2) + 1)
    median_control = img
    for _ in range(numpass):
        median_control = median_filter(median_control, medarr)
    assert_equal(median_test, median_control)
51
52
def test_bounding_box():
    """Check bounding_box on filled and empty 3D and 2D volumes.

    For a non-empty volume, ``mins`` is the first occupied index along each
    axis and ``maxs`` is one past the last (slice-style bounds). For an empty
    volume both are all zeros, and each call emits exactly one warning.
    """
    volume = np.zeros((100, 100, 50), dtype=int)

    # Typical 3D case: a filled rectangular region.
    volume[10:90, 11:40, 5:33] = 3
    lo, hi = bounding_box(volume)
    assert_equal(lo, [10, 11, 5])
    assert_equal(hi, [90, 40, 33])

    # 2D case: one slice through the filled region.
    lo, hi = bounding_box(volume[10])
    assert_equal(lo, [11, 5])
    assert_equal(hi, [40, 33])

    # Empty input: one extra warning per call and all-zero bounds.
    volume[:] = 0
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        baseline = len(caught)
        empty_inputs = (volume, volume[0])
        for calls_so_far, empty in enumerate(empty_inputs, start=1):
            lo, hi = bounding_box(empty)
            assert_equal(len(caught), baseline + calls_so_far)
            zeros = [0] * empty.ndim
            assert_equal(lo, zeros)
            assert_equal(hi, zeros)
85
86
def test_median_otsu():
    """Check median_otsu on a 3D volume and on a two-volume 4D stack.

    The test verifies four things:

    * the median/Otsu mask is tighter than a plain mean threshold;
    * the returned volume is the input with the mask applied;
    * dilation enlarges the mask;
    * a 4D input without ``vol_idx`` is rejected with ValueError.
    """
    fname = get_fnames('S0_10')
    data = load_nifti_data(fname)
    data = np.squeeze(data.astype('f8'))
    dummy_mask = data > data.mean()
    data_masked, mask = median_otsu(data, median_radius=3, numpass=2,
                                    autocrop=False, vol_idx=None,
                                    dilate=None)
    assert_equal(mask.sum() < dummy_mask.sum(), True)
    # Without autocrop the masked volume keeps the input shape. Voxels inside
    # the mask are unchanged, and voxels outside it are zeroed.
    assert_equal(data_masked.shape, data.shape)
    assert_equal(data_masked[mask], data[mask])
    assert_equal(data_masked[~mask], 0)

    # Stack two identical copies of the volume along a 4th axis.
    data2 = np.zeros(data.shape + (2,))
    data2[..., 0] = data
    data2[..., 1] = data

    data2_masked, mask2 = median_otsu(data2, median_radius=3, numpass=2,
                                      autocrop=False, vol_idx=[0, 1],
                                      dilate=None)
    # Voxel counts are integers, so compare them exactly.
    assert_equal(mask.sum(), mask2.sum())
    # The mask is 3D and is applied to every volume of the 4D input.
    assert_equal(mask2.shape, data.shape)
    assert_equal(data2_masked[mask2], data2[mask2])
    assert_equal(data2_masked[~mask2], 0)

    # Dilating the mask must enlarge it, and more iterations enlarge it more.
    _, mask3 = median_otsu(data2, median_radius=3, numpass=2,
                           autocrop=False, vol_idx=[0, 1],
                           dilate=1)
    assert_equal(mask2.sum() < mask3.sum(), True)

    _, mask4 = median_otsu(data2, median_radius=3, numpass=2,
                           autocrop=False, vol_idx=[0, 1],
                           dilate=2)
    assert_equal(mask3.sum() < mask4.sum(), True)

    # For 4D volumes, can't call without vol_idx input:
    assert_raises(ValueError, median_otsu, data2)
117
118
# Allow running this test module directly as a script.
if __name__ == '__main__':
    # NOTE(review): run_module_suite is numpy.testing's legacy nose-based
    # runner and may be unavailable in newer NumPy releases; confirm against
    # the NumPy version in use, or run this file through pytest instead.
    run_module_suite()
121