import logging
from time import time

import numpy as np

from dipy.workflows.workflow import Workflow
from dipy.io.image import save_nifti, load_nifti
from dipy.tracking import Streamlines
from dipy.segment.mask import median_otsu
from dipy.segment.bundles import RecoBundles
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import load_tractogram, save_tractogram


class MedianOtsuFlow(Workflow):
    """Workflow that computes a brain mask with ``median_otsu``."""

    @classmethod
    def get_short_name(cls):
        """Return the short identifier of this workflow."""
        return 'medotsu'

    def run(self, input_files, save_masked=False, median_radius=2, numpass=5,
            autocrop=False, vol_idx=None, dilate=None, out_dir='',
            out_mask='brain_mask.nii.gz', out_masked='dwi_masked.nii.gz'):
        """Workflow wrapping the median_otsu segmentation method.

        Applies median_otsu segmentation on each file found by 'globing'
        ``input_files`` and saves the results in a directory specified by
        ``out_dir``. The mask is always written; the masked volume is only
        written when ``save_masked`` is set.

        Parameters
        ----------
        input_files : string
            Path to the input volumes. This path may contain wildcards to
            process multiple inputs at once.
        save_masked : bool, optional
            Save the masked volume in addition to the mask.
        median_radius : int, optional
            Radius (in voxels) of the applied median filter.
        numpass : int, optional
            Number of pass of the median filter.
        autocrop : bool, optional
            If True, the masked input_volumes will also be cropped using the
            bounding box defined by the masked data. For example, if diffusion
            images are of 1x1x1 (mm^3) or higher resolution auto-cropping could
            reduce their size in memory and speed up some of the analysis.
        vol_idx : variable int, optional
            1D array representing indices of ``axis=-1`` of a 4D
            `input_volume`. From the command line use something like
            `3 4 5 6`. From script use something like `[3, 4, 5, 6]`. This
            input is required for 4D volumes.
        dilate : int, optional
            number of iterations for binary dilation.
        out_dir : string, optional
            Output directory. (default current directory)
        out_mask : string, optional
            Name of the mask volume to be saved.
        out_masked : string, optional
            Name of the masked volume to be saved.
        """
        # NOTE(review): get_io_iterator() takes no arguments, so it
        # presumably reads the paths from this method's own parameters;
        # confirm before renaming or rebinding any of them above this call.
        io_it = self.get_io_iterator()

        # Indices may arrive as strings (e.g. from the command line).
        if vol_idx is not None:
            vol_idx = [int(idx) for idx in vol_idx]

        for fpath, mask_out_path, masked_out_path in io_it:
            logging.info('Applying median_otsu segmentation on %s', fpath)

            data, affine, img = load_nifti(fpath, return_img=True)

            masked_volume, mask_volume = median_otsu(
                data,
                vol_idx=vol_idx,
                median_radius=median_radius,
                numpass=numpass,
                autocrop=autocrop, dilate=dilate)

            save_nifti(mask_out_path, mask_volume.astype(np.float64), affine)
            logging.info('Mask saved as %s', mask_out_path)

            if save_masked:
                save_nifti(masked_out_path, masked_volume, affine,
                           img.header)
                logging.info('Masked volume saved as %s', masked_out_path)

        return io_it


class RecoBundlesFlow(Workflow):
    """Workflow that recognizes model bundles inside a tractogram."""

    @classmethod
    def get_short_name(cls):
        """Return the short identifier of this workflow."""
        return 'recobundles'

    def run(self, streamline_files, model_bundle_files,
            greater_than=50, less_than=1000000,
            no_slr=False, clust_thr=15.,
            reduction_thr=15.,
            reduction_distance='mdf',
            model_clust_thr=2.5,
            pruning_thr=8.,
            pruning_distance='mdf',
            slr_metric='symmetric',
            slr_transform='similarity',
            slr_matrix='small',
            refine=False, r_reduction_thr=12.,
            r_pruning_thr=6., no_r_slr=False,
            out_dir='',
            out_recognized_transf='recognized.trk',
            out_recognized_labels='labels.npy'):
        """ Recognize bundles

        Parameters
        ----------
        streamline_files : string
            The path of streamline files where you want to recognize bundles.
        model_bundle_files : string
            The path of model bundle files.
        greater_than : int, optional
            Keep streamlines that have length greater than
            this value in mm.
        less_than : int, optional
            Keep streamlines that have length less than this value
            in mm.
        no_slr : bool, optional
            Don't enable local Streamline-based Linear
            Registration.
        clust_thr : float, optional
            MDF distance threshold for all streamlines.
        reduction_thr : float, optional
            Reduce search space by (mm).
        reduction_distance : string, optional
            Reduction distance type can be mdf or mam.
        model_clust_thr : float, optional
            MDF distance threshold for the model bundles.
        pruning_thr : float, optional
            Pruning after matching.
        pruning_distance : string, optional
            Pruning distance type can be mdf or mam.
        slr_metric : string, optional
            Options are None, symmetric, asymmetric or diagonal.
        slr_transform : string, optional
            Transformation allowed. translation, rigid, similarity or scaling.
        slr_matrix : string, optional
            Options are 'nano', 'tiny', 'small', 'medium', 'large', 'huge'.
            The value is case-insensitive; any other value raises a
            ValueError.
        refine : bool, optional
            Enable refine recognized bundle.
        r_reduction_thr : float, optional
            Refine reduce search space by (mm).
        r_pruning_thr : float, optional
            Refine pruning after matching.
        no_r_slr : bool, optional
            Don't enable Refine local Streamline-based Linear
            Registration.
        out_dir : string, optional
            Output directory. (default current directory)
        out_recognized_transf : string, optional
            Recognized bundle in the space of the model bundle.
        out_recognized_labels : string, optional
            Indices of recognized bundle in the original tractogram.

        References
        ----------
        .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
         bundles using local and global streamline-based registration and
         clustering, Neuroimage, 2017.

        .. [Chandio2020] Chandio, B.Q., Risacher, S.L., Pestilli, F.,
         Bullock, D., Yeh, FC., Koudoro, S., Rokem, A., Harezlak, J., and
         Garyfallidis, E. Bundle analytics, a computational framework for
         investigating the shapes and profiles of brain pathways across
         populations. Sci Rep 10, 17149 (2020)

        """
        slr = not no_slr
        r_slr = not no_r_slr

        # Value handed to RecoBundles as ``slr_select`` for each size name.
        select_by_matrix = {
            'nano': (100, 100),
            'tiny': (250, 250),
            'small': (400, 400),
            'medium': (600, 600),
            'large': (800, 800),
            'huge': (1200, 1200),
        }
        matrix_key = slr_matrix.lower()
        if matrix_key not in select_by_matrix:
            # Fail early with a clear message instead of hitting an unbound
            # ``slr_select`` later, after the tractogram has been loaded.
            raise ValueError(
                "Unknown slr_matrix '{0}'. Options are: {1}".format(
                    slr_matrix, ', '.join(select_by_matrix)))
        slr_select = select_by_matrix[matrix_key]

        all_bounds = [(-30, 30), (-30, 30), (-30, 30),
                      (-45, 45), (-45, 45), (-45, 45),
                      (0.8, 1.2), (0.8, 1.2), (0.8, 1.2)]
        # Number of leading bounds kept for each transform. An unrecognized
        # transform keeps the full list.
        n_bounds_by_transform = {
            'translation': 3,
            'rigid': 6,
            'similarity': 7,
            'scaling': 9,
        }
        transform_key = slr_transform.lower()
        bounds = all_bounds[:n_bounds_by_transform.get(transform_key,
                                                       len(all_bounds))]

        logging.info('### RecoBundles ###')

        # NOTE(review): get_io_iterator() takes no arguments, so it
        # presumably reads the paths from this method's own parameters;
        # the input/output parameters are deliberately not rebound above.
        io_it = self.get_io_iterator()

        t = time()
        logging.info(streamline_files)
        input_obj = load_tractogram(streamline_files, 'same',
                                    bbox_valid_check=False)
        streamlines = input_obj.streamlines

        logging.info(' Loading time %0.3f sec', time() - t)

        rb = RecoBundles(streamlines, greater_than=greater_than,
                         less_than=less_than)

        for _, mb, out_rec, out_labels in io_it:
            t = time()
            logging.info(mb)
            model_bundle = load_tractogram(mb, 'same',
                                           bbox_valid_check=False).streamlines
            logging.info(' Loading time %0.3f sec', time() - t)
            logging.info("model file = ")
            logging.info(mb)

            recognized_bundle, labels = rb.recognize(
                model_bundle,
                model_clust_thr=model_clust_thr,
                reduction_thr=reduction_thr,
                reduction_distance=reduction_distance,
                pruning_thr=pruning_thr,
                pruning_distance=pruning_distance,
                slr=slr,
                slr_metric=slr_metric,
                slr_x0=transform_key,
                slr_bounds=bounds,
                slr_select=slr_select,
                slr_method='L-BFGS-B')

            # Refinement is only attempted when more than one streamline
            # was recognized in the first pass.
            if refine and len(recognized_bundle) > 1:
                # affine
                x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])
                affine_bounds = [(-30, 30), (-30, 30), (-30, 30),
                                 (-45, 45), (-45, 45), (-45, 45),
                                 (0.8, 1.2), (0.8, 1.2), (0.8, 1.2),
                                 (-10, 10), (-10, 10), (-10, 10)]

                recognized_bundle, labels = rb.refine(
                    model_bundle,
                    recognized_bundle,
                    model_clust_thr=model_clust_thr,
                    reduction_thr=r_reduction_thr,
                    reduction_distance=reduction_distance,
                    pruning_thr=r_pruning_thr,
                    pruning_distance=pruning_distance,
                    slr=r_slr,
                    slr_metric=slr_metric,
                    slr_x0=x0,
                    slr_bounds=affine_bounds,
                    slr_select=slr_select,
                    slr_method='L-BFGS-B')

            if len(labels) > 0:
                ba, bmd = rb.evaluate_results(
                    model_bundle, recognized_bundle,
                    slr_select)

                logging.info("Bundle adjacency Metric %s", ba)
                logging.info("Bundle Min Distance Metric %s", bmd)

            # Reuse the tractogram loaded above as the spatial reference,
            # as LabelsBundlesFlow does, instead of passing the raw path.
            new_tractogram = StatefulTractogram(recognized_bundle,
                                                input_obj, Space.RASMM)
            save_tractogram(new_tractogram, out_rec, bbox_valid_check=False)
            logging.info('Saving output files ...')
            np.save(out_labels, np.array(labels))
            logging.info(out_rec)
            logging.info(out_labels)


class LabelsBundlesFlow(Workflow):
    """Workflow that extracts a bundle from a tractogram using indices."""

    @classmethod
    def get_short_name(cls):
        """Return the short identifier of this workflow."""
        return 'labelsbundles'

    def run(self, streamline_files, labels_files,
            out_dir='',
            out_bundle='recognized_orig.trk'):
        """ Extract bundles using existing indices (labels)

        Parameters
        ----------
        streamline_files : string
            The path of streamline files from which you want to extract
            bundles.
        labels_files : string
            The path of the label files (NumPy arrays of streamline
            indices, loaded with ``numpy.load``).
        out_dir : string, optional
            Output directory. (default current directory)
        out_bundle : string, optional
            Extracted bundle, selected from the input tractogram by the
            given indices.

        References
        ----------
        .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
         bundles using local and global streamline-based registration and
         clustering, Neuroimage, 2017.

        """
        logging.info('### Labels to Bundles ###')

        # NOTE(review): get_io_iterator() takes no arguments, so it
        # presumably reads the paths from this method's own parameters;
        # confirm before renaming or rebinding any of them above this call.
        io_it = self.get_io_iterator()
        for f_streamlines, f_labels, out_bundle_path in io_it:

            logging.info(f_streamlines)
            sft = load_tractogram(f_streamlines, 'same',
                                  bbox_valid_check=False)
            streamlines = sft.streamlines

            logging.info(f_labels)
            location = np.load(f_labels)
            # An empty index array yields an empty bundle rather than
            # indexing the tractogram with it.
            if len(location) < 1:
                bundle = Streamlines([])
            else:
                bundle = streamlines[location]

            logging.info('Saving output files ...')
            new_sft = StatefulTractogram(bundle, sft, Space.RASMM)
            save_tractogram(new_sft, out_bundle_path, bbox_valid_check=False)
            logging.info(out_bundle_path)