1
2from warnings import warn
3
4import numpy as np
5
6from dipy.reconst.dti import fractional_anisotropy, color_fa
7
8from scipy.ndimage.filters import median_filter
9try:
10    from skimage.filters import threshold_otsu as otsu
11except Exception:
12    from dipy.segment.threshold import otsu
13
14from scipy.ndimage import binary_dilation, generate_binary_structure
15
16
def multi_median(input, median_radius, numpass):
    """Apply a median filter to `input` several times in succession.

    Parameters
    ----------
    input : ndarray
        The input volume to apply filter on. The caller's array is not
        modified.
    median_radius : int
        Radius (in voxels) of the applied median filter; the filter window is
        ``2 * median_radius + 1`` voxels wide along every axis.
    numpass: int
        Number of passes of the median filter.

    Returns
    -------
    input : ndarray
        Filtered input volume. When ``numpass < 1`` this is `input` itself.
    """
    # One window length per axis of the input.
    window = (2 * median_radius + 1) * np.ones_like(input.shape)

    if numpass > 1:
        # From the second pass on, the filter writes into the buffer that
        # initially held the data; use a private copy so the caller's array
        # is left untouched.
        input = input.copy()

    # Ping-pong between two buffers: read from `src`, write into `dst`, then
    # swap so the freshest result is always in `src`.
    src = input
    dst = np.empty_like(src)
    for _ in range(numpass):
        median_filter(src, window, output=dst)
        src, dst = dst, src
    return src
47
48
def applymask(vol, mask):
    """Mask `vol` with `mask`, broadcasting the mask over trailing axes.

    Parameters
    ----------
    vol : ndarray
        Array with $V$ dimensions.
    mask : ndarray
        Binary mask with $M$ dimensions, where $M <= V$. When $M < V$,
        ``V - M`` trailing axes of length 1 are appended to `mask` so that it
        broadcasts against `vol`. In the typical case `vol` is 4D and `mask`
        is 3D: one axis is appended to the mask, which (via numpy
        broadcasting) applies the 3D mask to each 3D volume in `vol`
        (``vol[..., 0]`` to ``vol[..., -1]``).

    Returns
    -------
    masked_vol : ndarray
        `vol` multiplied by `mask`, where `mask` may have been extended to
        match the extra dimensions of `vol`.
    """
    n_extra = vol.ndim - mask.ndim
    broadcastable = mask.reshape(mask.shape + (1,) * n_extra)
    return vol * broadcastable
72
73
def bounding_box(vol):
    """Compute the bounding box of nonzero intensity voxels in the volume.

    Parameters
    ----------
    vol : ndarray
        Volume to compute bounding box on.

    Returns
    -------
    npmins : list
        Minimum index (inclusive) of the nonzero voxels along each dimension.
    npmaxs : list
        One past the maximum index (exclusive) of the nonzero voxels along
        each dimension, so ``slice(npmins[d], npmaxs[d])`` spans the box.

    Notes
    -----
    If `vol` contains no nonzero voxels, a warning is issued and both lists
    are filled with zeros (an empty box).
    """
    # Collapse every axis but the first into a boolean "any nonzero" profile.
    # For 1D input no reduction happens, so compare against zero explicitly:
    # argmax on the raw values would locate the largest value rather than the
    # first nonzero one (and would treat all-negative data as empty).
    temp = vol
    for _ in range(vol.ndim - 1):
        temp = temp.any(-1)
    temp = temp != 0

    # Check that vol is not all 0. This also covers zero-length axes, on
    # which argmax would raise.
    if not temp.any():
        warn('No data found in volume to bound. Returning empty bounding box.')
        return [0] * vol.ndim, [0] * vol.ndim

    # argmax on a boolean array returns the position of the first True.
    mins = [temp.argmax()]
    maxs = [len(temp) - temp[::-1].argmax()]

    # Find bounds on the remaining dimensions by recursing on the projection
    # along the first axis.
    if vol.ndim > 1:
        sub_mins, sub_maxs = bounding_box(vol.any(0))
        mins.extend(sub_mins)
        maxs.extend(sub_maxs)
    return mins, maxs
105
106
def crop(vol, mins, maxs):
    """Crop `vol` to the box ``[mins[d], maxs[d])`` along each leading axis.

    Parameters
    ----------
    vol : ndarray
        Volume to crop.
    mins : array
        Start index (inclusive) for each dimension.
    maxs : array
        Stop index (exclusive) for each dimension.

    Returns
    -------
    vol : ndarray
        The cropped volume. Axes beyond ``min(len(mins), len(maxs))`` are left
        uncropped.
    """
    box = tuple(map(slice, mins, maxs))
    return vol[box]
125
126
def median_otsu(input_volume, vol_idx=None, median_radius=4, numpass=4,
                autocrop=False, dilate=None):
    """Simple brain extraction tool method for images from DWI data.

    It uses a median filter smoothing of the input_volumes `vol_idx` and an
    automatic histogram Otsu thresholding technique, hence the name
    *median_otsu*.

    This function is inspired from Mrtrix's bet which has default values
    ``median_radius=3``, ``numpass=2``. However, from tests on multiple 1.5T
    and 3T data from GE, Philips, Siemens, the most robust choice is
    ``median_radius=4``, ``numpass=4``.

    Parameters
    ----------
    input_volume : ndarray
        3D or 4D array of the brain volume.
    vol_idx : None or array, optional.
        1D array representing indices of ``axis=3`` of a 4D `input_volume`.
        The selected volumes are averaged before filtering.
        None is only an acceptable input if ``input_volume`` is 3D.
    median_radius : int
        Radius (in voxels) of the applied median filter (default: 4).
    numpass: int
        Number of passes of the median filter (default: 4).
    autocrop: bool, optional
        if True, the masked input_volume will also be cropped using the
        bounding box defined by the masked data. Should be on if DWI is
        upsampled to 1x1x1 resolution. (default: False).
    dilate : None or int, optional
        Number of iterations for binary dilation of the mask, using a
        connectivity-1 structuring element with the same number of
        dimensions as the mask. None (default) skips dilation.

    Returns
    -------
    maskedvolume : ndarray
        Masked input_volume (cropped as well when `autocrop` is True).
    mask : ndarray
        The binary brain mask. It is 3D for 3D and 4D inputs (cropped as well
        when `autocrop` is True).

    Raises
    ------
    ValueError
        If `input_volume` is 4D and `vol_idx` is None.

    Notes
    -----
    Copyright (C) 2011, the scikit-image team
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
    met:

     1. Redistributions of source code must retain the above copyright
        notice, this list of conditions and the following disclaimer.
     2. Redistributions in binary form must reproduce the above copyright
        notice, this list of conditions and the following disclaimer in
        the documentation and/or other materials provided with the
        distribution.
     3. Neither the name of skimage nor the names of its contributors may be
        used to endorse or promote products derived from this software without
        specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
    IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.
    """
    if input_volume.ndim == 4:
        if vol_idx is None:
            raise ValueError("For 4D images, must provide vol_idx input")
        # Average the selected volumes (presumably the b0 volumes, given the
        # variable name) into a single 3D image.
        b0vol = np.mean(input_volume[..., tuple(vol_idx)], axis=3)
    else:
        b0vol = input_volume

    # Make a mask using a multiple pass median filter and histogram
    # thresholding.
    mask = multi_median(b0vol, median_radius, numpass)
    thresh = otsu(mask)
    mask = mask > thresh

    if dilate is not None:
        # Build the structuring element with the mask's own rank; a fixed
        # rank of 3 made dilation fail for inputs of any other dimensionality.
        # NOTE(review): scipy's binary_dilation repeats until the result stops
        # changing when iterations < 1, so dilate=0 does not mean "no
        # dilation" -- confirm this is intended.
        cross = generate_binary_structure(mask.ndim, 1)
        mask = binary_dilation(mask, cross, iterations=dilate)

    # Auto crop the volumes using the mask as input_volume for bounding box
    # computing.
    if autocrop:
        mins, maxs = bounding_box(mask)
        mask = crop(mask, mins, maxs)
        croppedvolume = crop(input_volume, mins, maxs)
        maskedvolume = applymask(croppedvolume, mask)
    else:
        maskedvolume = applymask(input_volume, mask)
    return maskedvolume, mask
224
225
def segment_from_cfa(tensor_fit, roi, threshold, return_cfa=False):
    """
    Segment the cfa inside roi using the values from threshold as bounds.

    The FA computed from ``tensor_fit.evals`` has NaNs replaced by 0 and is
    clipped to [0, 1] before the color FA is built from it and
    ``tensor_fit.evecs``.

    Parameters
    ----------
    tensor_fit : TensorFit object
        TensorFit object; its ``evals`` and ``evecs`` are used.

    roi : ndarray
        A binary mask, which contains the bounding box for the segmentation.

    threshold : array-like
        The min and max values to use for the thresholding, given as
        (R_min, R_max, G_min, G_max, B_min, B_max). Both bounds are
        inclusive. Must support slicing (e.g. a list, tuple or ndarray).

    return_cfa : bool, optional
        If True, the cfa is also returned.

    Returns
    -------
    mask : ndarray
        Binary mask of the segmentation: True where the voxel lies inside
        `roi` and all three cfa channels are within their bounds.

    cfa : ndarray, optional
        Array with shape = (..., 3), where ... is the shape of tensor_fit.
        The color fractional anisotropy, ordered as a nd array with the last
        dimension of size 3 for the R, G and B channels.
    """
    fa = fractional_anisotropy(tensor_fit.evals)
    fa[np.isnan(fa)] = 0
    # Clamp the FA to remove degenerate tensors.
    fa = np.clip(fa, 0, 1)

    cfa = color_fa(fa, tensor_fit.evecs)
    # Trailing singleton axis so the ROI broadcasts over the RGB channels.
    inside_roi = np.asarray(roi, dtype=bool)[..., None]

    lower = threshold[0::2]
    upper = threshold[1::2]
    in_range = (cfa >= lower) & (cfa <= upper) & inside_roi
    mask = in_range.all(axis=-1)

    return (mask, cfa) if return_cfa else mask
273
274
def clean_cc_mask(mask):
    """
    Cleans a segmentation of the corpus callosum so no random pixels
    are included.

    Only the largest connected component of `mask` is kept; if several
    components tie for largest, all of them are kept. Connectivity is that of
    ``scipy.ndimage.label``'s default structuring element.

    Parameters
    ----------
    mask : ndarray
        Binary mask of the coarse segmentation.

    Returns
    -------
    new_cc_mask : ndarray
        Binary mask of the cleaned segmentation, as a float array of 0s and 1s
        with the same shape as `mask`. All zeros if `mask` has no foreground.
    """
    # Import from scipy.ndimage directly: the scipy.ndimage.measurements
    # namespace is deprecated.
    from scipy.ndimage import label

    new_cc_mask = np.zeros(mask.shape)

    # Flood fill algorithm to find contiguous regions; labels run from 1 to
    # num_labels and 0 is background.
    labels, num_labels = label(mask)
    if num_labels == 0:
        # Nothing to keep (taking the max of no component sizes would raise).
        return new_cc_mask

    # Voxel count of every label in a single pass; ignore the background.
    volumes = np.bincount(labels.ravel())
    volumes[0] = 0

    # Every label whose size equals the maximum. np.isin performs a proper
    # membership test; ``labels == array_of_labels`` broadcast against the
    # last axis and failed or misbehaved when sizes tied.
    biggest_labels = np.flatnonzero(volumes == volumes.max())
    new_cc_mask[np.isin(labels, biggest_labels)] = 1

    return new_cc_mask
304