"""Tests for the masking utilities in ``dipy.segment.mask``."""

import warnings

import numpy as np
from scipy.ndimage import generate_binary_structure, binary_dilation
# ``scipy.ndimage.filters`` is a deprecated alias namespace; the public
# location of ``median_filter`` is ``scipy.ndimage``.
from scipy.ndimage import median_filter

from dipy.segment.mask import (otsu, bounding_box, crop, applymask,
                               multi_median, median_otsu)

from numpy.testing import (assert_equal,
                           assert_almost_equal,
                           assert_raises)
from dipy.data import get_fnames
from dipy.io.image import load_nifti_data


def test_mask():
    """Exercise otsu, bounding_box/crop, applymask and multi_median.

    A single voxel dilated into a binary blob must keep its voxel count
    through Otsu thresholding, cropping to its bounding box, and masking.
    """
    vol = np.zeros((30, 30, 30))
    vol[15, 15, 15] = 1
    struct = generate_binary_structure(3, 1)
    voln = binary_dilation(vol, structure=struct, iterations=4).astype('f4')
    initial = np.sum(voln > 0)

    mask = voln.copy()
    thresh = otsu(mask)
    mask = mask > thresh
    initial_otsu = np.sum(mask > 0)
    assert_equal(initial_otsu, initial)

    mins, maxs = bounding_box(mask)
    voln_crop = crop(mask, mins, maxs)
    initial_crop = np.sum(voln_crop > 0)
    assert_equal(initial_crop, initial)

    # Check the value applymask produces. Counting ``voln`` afterwards (as
    # before) would not exercise applymask at all unless it worked in place.
    masked = applymask(voln, mask)
    final = np.sum(masked > 0)
    assert_equal(final, initial)

    # Test multi_median: it must leave its input untouched and match
    # ``numpass`` successive median_filter passes with a (2r + 1) window.
    img = np.arange(25).reshape(5, 5)
    img_copy = img.copy()
    medianradius = 2
    numpass = 3
    median_test = multi_median(img, medianradius, numpass)
    assert_equal(img, img_copy)

    medarr = np.full(img.ndim, 2 * medianradius + 1)
    median_control = img
    for _ in range(numpass):
        median_control = median_filter(median_control, medarr)
    assert_equal(median_test, median_control)


def test_bounding_box():
    """Check bounding_box on 3D and 2D volumes, including the empty case.

    For an all-zero volume, bounding_box is expected to emit one warning
    per call and return zeros for both the minimum and maximum corners.
    """
    vol = np.zeros((100, 100, 50), dtype=int)

    # Check the more usual case
    vol[10:90, 11:40, 5:33] = 3
    mins, maxs = bounding_box(vol)
    assert_equal(mins, [10, 11, 5])
    assert_equal(maxs, [90, 40, 33])

    # Check a 2d case
    mins, maxs = bounding_box(vol[10])
    assert_equal(mins, [11, 5])
    assert_equal(maxs, [40, 33])

    vol[:] = 0
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # Trigger a warning.
        num_warns = len(w)
        mins, maxs = bounding_box(vol)
        # Assert number of warnings has gone up by 1
        assert_equal(len(w), num_warns + 1)

        # Check that an empty array returns zeros for both min & max
        assert_equal(mins, [0, 0, 0])
        assert_equal(maxs, [0, 0, 0])

        # Check the 2d case
        mins, maxs = bounding_box(vol[0])
        assert_equal(len(w), num_warns + 2)
        assert_equal(mins, [0, 0])
        assert_equal(maxs, [0, 0])


def test_median_otsu():
    """Check median_otsu on the ``S0_10`` dataset for 3D and 4D input."""
    fname = get_fnames('S0_10')
    data = load_nifti_data(fname)
    data = np.squeeze(data.astype('f8'))
    dummy_mask = data > data.mean()
    _, mask = median_otsu(data, median_radius=3, numpass=2,
                          autocrop=False, vol_idx=None,
                          dilate=None)
    # The median_otsu mask should be tighter than a naive above-mean mask.
    assert_equal(mask.sum() < dummy_mask.sum(), True)

    # Two identical copies stacked along a 4th axis should yield the same
    # mask as the single 3D volume.
    data2 = np.zeros(data.shape + (2,))
    data2[..., 0] = data
    data2[..., 1] = data

    _, mask2 = median_otsu(data2, median_radius=3, numpass=2,
                           autocrop=False, vol_idx=[0, 1],
                           dilate=None)
    assert_almost_equal(mask.sum(), mask2.sum())

    # Increasing ``dilate`` should strictly grow the mask.
    _, mask3 = median_otsu(data2, median_radius=3, numpass=2,
                           autocrop=False, vol_idx=[0, 1],
                           dilate=1)
    assert_equal(mask2.sum() < mask3.sum(), True)

    _, mask4 = median_otsu(data2, median_radius=3, numpass=2,
                           autocrop=False, vol_idx=[0, 1],
                           dilate=2)
    assert_equal(mask3.sum() < mask4.sum(), True)

    # For 4D volumes, can't call without vol_idx input:
    assert_raises(ValueError, median_otsu, data2)


if __name__ == '__main__':
    # run_module_suite is numpy's legacy nose-based runner; import it only
    # when this file is executed directly so collecting the tests does not
    # depend on it.
    from numpy.testing import run_module_suite
    run_module_suite()