1
2import logging
3from dipy.workflows.workflow import Workflow
4from dipy.io.image import save_nifti, load_nifti
5import numpy as np
6from time import time
7from dipy.tracking import Streamlines
8from dipy.segment.mask import median_otsu
9from dipy.segment.bundles import RecoBundles
10from dipy.io.stateful_tractogram import Space, StatefulTractogram
11from dipy.io.streamline import load_tractogram, save_tractogram
12
13
class MedianOtsuFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return 'medotsu'

    def run(self, input_files, save_masked=False, median_radius=2, numpass=5,
            autocrop=False, vol_idx=None, dilate=None, out_dir='',
            out_mask='brain_mask.nii.gz', out_masked='dwi_masked.nii.gz'):
        """Segment brain volumes with the median_otsu method.

        Every file matched by ``input_files`` (wildcards allowed) is
        segmented with ``median_otsu``; the resulting mask (and optionally
        the masked volume) is written to ``out_dir``.

        Parameters
        ----------
        input_files : string
            Path to the input volumes. This path may contain wildcards to
            process multiple inputs at once.
        save_masked : bool, optional
            Also save the masked volume returned by median_otsu (the mask
            itself is always saved).
        median_radius : int, optional
            Radius (in voxels) of the applied median filter.
        numpass : int, optional
            Number of passes of the median filter.
        autocrop : bool, optional
            If True, the masked input_volumes will also be cropped using the
            bounding box defined by the masked data. For example, if diffusion
            images are of 1x1x1 (mm^3) or higher resolution auto-cropping could
            reduce their size in memory and speed up some of the analysis.
        vol_idx : variable int, optional
            1D array representing indices of ``axis=-1`` of a 4D
            `input_volume`. From the command line use something like
            `3 4 5 6`. From script use something like `[3, 4, 5, 6]`. This
            input is required for 4D volumes.
        dilate : int, optional
            Number of iterations for binary dilation.
        out_dir : string, optional
            Output directory. (default current directory)
        out_mask : string, optional
            Name of the mask volume to be saved.
        out_masked : string, optional
            Name of the masked volume to be saved.
        """
        # Must be called directly from ``run``: it inspects this frame's
        # arguments to build the input/output path triples.
        io_it = self.get_io_iterator()

        # Indices may arrive as strings from the command line.
        volume_indices = None if vol_idx is None else list(map(int, vol_idx))
        otsu_options = dict(vol_idx=volume_indices,
                            median_radius=median_radius,
                            numpass=numpass,
                            autocrop=autocrop,
                            dilate=dilate)

        for in_path, mask_path, masked_path in io_it:
            logging.info('Applying median_otsu segmentation on %s', in_path)

            data, affine, img = load_nifti(in_path, return_img=True)
            masked_volume, mask_volume = median_otsu(data, **otsu_options)

            # NOTE(review): when autocrop=True the returned arrays may be
            # cropped while the original affine is saved unchanged — confirm
            # whether the affine should be shifted to the crop origin.
            save_nifti(mask_path, mask_volume.astype(np.float64), affine)
            logging.info('Mask saved as %s', mask_path)

            if not save_masked:
                continue

            save_nifti(masked_path, masked_volume, affine, img.header)
            logging.info('Masked volume saved as %s', masked_path)

        return io_it
87
88
class RecoBundlesFlow(Workflow):
    # ``slr_matrix`` option -> value forwarded as ``slr_select`` to
    # ``RecoBundles.recognize`` / ``RecoBundles.refine``.
    _SLR_SELECT = {
        'nano': (100, 100),
        'tiny': (250, 250),
        'small': (400, 400),
        'medium': (600, 600),
        'large': (800, 800),
        'huge': (1200, 1200),
    }

    # Bounds for the SLR parameters, ordered as translations, rotations,
    # scalings (three each).
    _SLR_BOUNDS = [(-30, 30), (-30, 30), (-30, 30),
                   (-45, 45), (-45, 45), (-45, 45),
                   (0.8, 1.2), (0.8, 1.2), (0.8, 1.2)]

    # ``slr_transform`` option -> number of leading ``_SLR_BOUNDS`` entries
    # (free SLR parameters) used for that transform.
    _SLR_N_PARAMS = {
        'translation': 3,
        'rigid': 6,
        'similarity': 7,
        'scaling': 9,
    }

    @classmethod
    def get_short_name(cls):
        return 'recobundles'

    def run(self, streamline_files, model_bundle_files,
            greater_than=50, less_than=1000000,
            no_slr=False, clust_thr=15.,
            reduction_thr=15.,
            reduction_distance='mdf',
            model_clust_thr=2.5,
            pruning_thr=8.,
            pruning_distance='mdf',
            slr_metric='symmetric',
            slr_transform='similarity',
            slr_matrix='small',
            refine=False, r_reduction_thr=12.,
            r_pruning_thr=6., no_r_slr=False,
            out_dir='',
            out_recognized_transf='recognized.trk',
            out_recognized_labels='labels.npy'):
        """ Recognize bundles

        Parameters
        ----------
        streamline_files : string
            The path of streamline files where you want to recognize bundles.
        model_bundle_files : string
            The path of model bundle files.
        greater_than : int, optional
            Keep streamlines that have length greater than
            this value in mm.
        less_than : int, optional
            Keep streamlines that have length less than this value
            in mm.
        no_slr : bool, optional
            Don't enable local Streamline-based Linear
            Registration.
        clust_thr : float, optional
            MDF distance threshold used to cluster all input streamlines.
        reduction_thr : float, optional
            Reduce search space by (mm).
        reduction_distance : string, optional
            Reduction distance type can be mdf or mam.
        model_clust_thr : float, optional
            MDF distance threshold for the model bundles.
        pruning_thr : float, optional
            Pruning after matching.
        pruning_distance : string, optional
            Pruning distance type can be mdf or mam.
        slr_metric : string, optional
            Options are None, symmetric, asymmetric or diagonal.
        slr_transform : string, optional
            Transformation allowed. translation, rigid, similarity or scaling.
            Any other value raises ValueError.
        slr_matrix : string, optional
            Options are 'nano', 'tiny', 'small', 'medium', 'large', 'huge'.
            Any other value raises ValueError.
        refine : bool, optional
            Enable refine recognized bundle.
        r_reduction_thr : float, optional
            Refine reduce search space by (mm).
        r_pruning_thr : float, optional
            Refine pruning after matching.
        no_r_slr : bool, optional
            Don't enable Refine local Streamline-based Linear
            Registration.
        out_dir : string, optional
            Output directory. (default current directory)
        out_recognized_transf : string, optional
            Recognized bundle in the space of the model bundle.
        out_recognized_labels : string, optional
            Indices of recognized bundle in the original tractogram.

        References
        ----------
        .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
         bundles using local and global streamline-based registration and
         clustering, Neuroimage, 2017.

        .. [Chandio2020] Chandio, B.Q., Risacher, S.L., Pestilli, F.,
        Bullock, D., Yeh, FC., Koudoro, S., Rokem, A., Harezlak, J., and
        Garyfallidis, E. Bundle analytics, a computational framework for
        investigating the shapes and profiles of brain pathways across
        populations. Sci Rep 10, 17149 (2020)

        """
        slr = not no_slr
        r_slr = not no_r_slr

        # Validate the option strings up front so a typo fails immediately
        # instead of after the (potentially large) tractogram is loaded.
        slr_matrix = slr_matrix.lower()
        try:
            slr_select = self._SLR_SELECT[slr_matrix]
        except KeyError:
            raise ValueError(
                "Unknown slr_matrix '{0}'; expected one of: {1}.".format(
                    slr_matrix, ', '.join(self._SLR_SELECT))) from None

        slr_transform = slr_transform.lower()
        try:
            n_params = self._SLR_N_PARAMS[slr_transform]
        except KeyError:
            raise ValueError(
                "Unknown slr_transform '{0}'; expected one of: {1}.".format(
                    slr_transform, ', '.join(self._SLR_N_PARAMS))) from None
        # Slicing returns a fresh list, so the class-level table is never
        # shared with the optimizer.
        bounds = self._SLR_BOUNDS[:n_params]

        logging.info('### RecoBundles ###')

        # Must be called directly from ``run``: it inspects this frame's
        # arguments to build the input/output path tuples.
        io_it = self.get_io_iterator()

        t = time()
        logging.info(streamline_files)
        input_obj = load_tractogram(streamline_files, 'same',
                                    bbox_valid_check=False)
        streamlines = input_obj.streamlines

        logging.info(' Loading time %0.3f sec' % (time() - t,))

        # clust_thr was previously accepted but never used; forward it.
        rb = RecoBundles(streamlines, greater_than=greater_than,
                         less_than=less_than, clust_thr=clust_thr)

        for _, mb, out_rec, out_labels in io_it:
            t = time()
            logging.info(mb)
            model_bundle = load_tractogram(mb, 'same',
                                           bbox_valid_check=False).streamlines
            logging.info(' Loading time %0.3f sec' % (time() - t,))
            logging.info("model file = ")
            logging.info(mb)

            recognized_bundle, labels = \
                rb.recognize(
                    model_bundle,
                    model_clust_thr=model_clust_thr,
                    reduction_thr=reduction_thr,
                    reduction_distance=reduction_distance,
                    pruning_thr=pruning_thr,
                    pruning_distance=pruning_distance,
                    slr=slr,
                    slr_metric=slr_metric,
                    slr_x0=slr_transform,
                    slr_bounds=bounds,
                    slr_select=slr_select,
                    slr_method='L-BFGS-B')

            # Refinement needs more than one recognized streamline.
            if refine and len(recognized_bundle) > 1:
                # 12-parameter (affine) initial guess and bounds: the 9
                # similarity/scaling parameters plus 3 extra ones.
                x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])
                affine_bounds = [(-30, 30), (-30, 30), (-30, 30),
                                 (-45, 45), (-45, 45), (-45, 45),
                                 (0.8, 1.2), (0.8, 1.2), (0.8, 1.2),
                                 (-10, 10), (-10, 10), (-10, 10)]

                recognized_bundle, labels = \
                    rb.refine(
                        model_bundle,
                        recognized_bundle,
                        model_clust_thr=model_clust_thr,
                        reduction_thr=r_reduction_thr,
                        reduction_distance=reduction_distance,
                        pruning_thr=r_pruning_thr,
                        pruning_distance=pruning_distance,
                        slr=r_slr,
                        slr_metric=slr_metric,
                        slr_x0=x0,
                        slr_bounds=affine_bounds,
                        slr_select=slr_select,
                        slr_method='L-BFGS-B')

            if len(labels) > 0:
                ba, bmd = rb.evaluate_results(
                    model_bundle, recognized_bundle,
                    slr_select)

                logging.info("Bundle adjacency Metric {0}".format(ba))
                logging.info("Bundle Min Distance Metric {0}".format(bmd))

            # Reuse the already-loaded tractogram as the spatial reference
            # instead of re-reading its header from disk.
            new_tractogram = StatefulTractogram(recognized_bundle,
                                                input_obj, Space.RASMM)
            save_tractogram(new_tractogram, out_rec, bbox_valid_check=False)
            logging.info('Saving output files ...')
            np.save(out_labels, np.array(labels))
            logging.info(out_rec)
            logging.info(out_labels)
286
287
class LabelsBundlesFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return 'labelsbundles'

    def run(self, streamline_files, labels_files,
            out_dir='',
            out_bundle='recognized_orig.trk'):
        """ Extract bundles using existing indices (labels)

        Parameters
        ----------
        streamline_files : string
            The path of streamline files from which bundles are extracted.
        labels_files : string
            The path of labels files: ``.npy`` arrays of streamline indices
            into the corresponding tractogram (such as the labels file
            written by the recobundles workflow).
        out_dir : string, optional
            Output directory. (default current directory)
        out_bundle : string, optional
            Extracted bundle, saved in the space of the input tractogram.

        References
        ----------
        .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
         bundles using local and global streamline-based registration and
         clustering, Neuroimage, 2017.

        """
        logging.info('### Labels to Bundles ###')

        # Must be called directly from ``run``: it inspects this frame's
        # arguments to build the input/output path tuples.
        io_it = self.get_io_iterator()
        # Distinct loop name so the ``out_bundle`` parameter is not shadowed.
        for f_streamlines, f_labels, out_bundle_path in io_it:

            logging.info(f_streamlines)
            sft = load_tractogram(f_streamlines, 'same',
                                  bbox_valid_check=False)
            streamlines = sft.streamlines

            logging.info(f_labels)
            location = np.load(f_labels)
            # No labels -> empty bundle. Special-cased presumably because an
            # empty label list saved via np.array([]) has a float dtype and
            # cannot be used to index the streamlines.
            if len(location) == 0:
                bundle = Streamlines([])
            else:
                bundle = streamlines[location]

            logging.info('Saving output files ...')
            # The loaded tractogram is the reference, so the bundle keeps the
            # input tractogram's space.
            new_sft = StatefulTractogram(bundle, sft, Space.RASMM)
            save_tractogram(new_sft, out_bundle_path, bbox_valid_check=False)
            logging.info(out_bundle_path)
337