1
2import logging
3import numpy as np
4import os
5import json
6import warnings
7from time import time
8from scipy.ndimage.morphology import binary_dilation
9from dipy.utils.optpkg import optional_package
10from dipy.io import read_bvals_bvecs
11from dipy.io.image import load_nifti, save_nifti
12from dipy.core.gradients import gradient_table
13from dipy.reconst.dti import TensorModel
14from dipy.io.peaks import load_peaks
15from dipy.io.stateful_tractogram import Space, StatefulTractogram
16from dipy.io.streamline import load_tractogram, save_tractogram
17from dipy.segment.mask import segment_from_cfa
18from dipy.segment.mask import bounding_box
19# from dipy.io.streamline import load_trk, save_trk
20from dipy.tracking.streamline import transform_streamlines
21from glob import glob
22from dipy.workflows.workflow import Workflow
23from dipy.segment.bundles import bundle_shape_similarity
24from dipy.stats.analysis import assignment_map
25from dipy.stats.analysis import anatomical_measures
26from dipy.stats.analysis import peak_values
27
28pd, have_pd, _ = optional_package("pandas")
29smf, have_smf, _ = optional_package("statsmodels")
30tables, have_tables, _ = optional_package("tables")
31matplt, have_matplotlib, _ = optional_package("matplotlib")
32
33if have_pd:
34    import pandas as pd
35
36if have_smf:
37    import statsmodels.formula.api as smf
38
39if have_matplotlib:
40    import matplotlib as matplt
41    import matplotlib.pyplot as plt
42
43
class SNRinCCFlow(Workflow):

    @classmethod
    def get_short_name(cls):
        return 'snrincc'

    def run(self, data_files, bvals_files, bvecs_files, mask_file,
            bbox_threshold=[0.6, 1, 0, 0.1, 0, 0.1], out_dir='',
            out_file='product.json', out_mask_cc='cc.nii.gz',
            out_mask_noise='mask_noise.nii.gz'):
        """Compute the signal-to-noise ratio in the corpus callosum.

        A corpus callosum (CC) mask is segmented with ``segment_from_cfa``
        inside the central part of the brain mask's bounding box. A noise
        mask is built from voxels outside a 10-iteration dilation of the
        brain mask, restricted to the second half of the last axis. The SNR
        is the mean CC signal divided by the noise standard deviation. It is
        reported for the first volume (labelled ``b0``) and for the
        volumes whose gradient directions are closest to the X, Y and Z
        axes. ``IOError`` is raised when ``bbox_threshold`` does not hold 6
        values or when the DWI data is neither 3D nor 4D.

        Parameters
        ----------
        data_files : string
            Path to the dwi.nii.gz file. This path may contain wildcards to
            process multiple inputs at once.
        bvals_files : string
            Path of bvals.
        bvecs_files : string
            Path of bvecs.
        mask_file : string
            Path of a brain mask file.
        bbox_threshold : variable float, optional
            Threshold for bounding box, values separated with commas for ex.
            [0.6,1,0,0.1,0,0.1].
        out_dir : string, optional
            Where the resulting file will be saved. (default current directory)
        out_file : string, optional
            Name of the result file to be saved.
        out_mask_cc : string, optional
            Name of the CC mask volume to be saved.
        out_mask_noise : string, optional
            Name of the mask noise volume to be saved.

        """
        io_it = self.get_io_iterator()

        # Validate once, before the (expensive) tensor fit of any subject.
        if len(bbox_threshold) != 6:
            raise IOError('bbox_threshold should have 6 float values')

        for dwi_path, bvals_path, bvecs_path, mask_path, out_path, \
                cc_mask_path, mask_noise_path in io_it:
            data, affine = load_nifti(dwi_path)
            bvals, bvecs = read_bvals_bvecs(bvals_path, bvecs_path)
            gtab = gradient_table(bvals=bvals, bvecs=bvecs)

            mask, affine = load_nifti(mask_path)

            logging.info('Computing tensors...')
            tenmodel = TensorModel(gtab)
            tensorfit = tenmodel.fit(data, mask=mask)

            logging.info(
                'Computing worst-case/best-case SNR using the CC...')

            if np.ndim(data) == 4:
                CC_box = np.zeros_like(data[..., 0])
            elif np.ndim(data) == 3:
                CC_box = np.zeros_like(data)
            else:
                raise IOError('DWI data has invalid dimensions')

            # Keep only the central half (a quarter trimmed from each side)
            # of the brain mask's bounding box along every axis.
            mins, maxs = bounding_box(mask)
            mins = np.array(mins)
            maxs = np.array(maxs)
            diff = (maxs - mins) // 4
            bounds_min = mins + diff
            bounds_max = maxs - diff

            CC_box[bounds_min[0]:bounds_max[0],
                   bounds_min[1]:bounds_max[1],
                   bounds_min[2]:bounds_max[2]] = 1

            mask_cc_part, cfa = segment_from_cfa(tensorfit, CC_box,
                                                 bbox_threshold,
                                                 return_cfa=True)

            if not np.count_nonzero(mask_cc_part.astype(np.uint8)):
                logging.warning("Empty mask: corpus callosum not found."
                                " Update your data or your threshold")

            save_nifti(cc_mask_path, mask_cc_part.astype(np.uint8), affine)
            logging.info('CC mask saved as {0}'.format(cc_mask_path))

            # Mean signal per volume inside the CC. With an empty CC mask we
            # use zeros so that indexing by volume below cannot fail (a bare
            # scalar 0 is not subscriptable).
            masked_data = data[mask_cc_part]
            if masked_data.size:
                mean_signal = np.mean(masked_data, axis=0)
            else:
                mean_signal = np.zeros(data.shape[-1])

            # Background: outside a 10-iteration dilation of the brain mask
            # and within the second half of the last axis.
            mask_noise = binary_dilation(mask, iterations=10)
            mask_noise[..., :mask_noise.shape[-1]//2] = 1
            mask_noise = ~mask_noise

            save_nifti(mask_noise_path, mask_noise.astype(np.uint8), affine)
            logging.info('Mask noise saved as {0}'.format(mask_noise_path))

            noise_std = 0
            if np.count_nonzero(mask_noise.astype(np.uint8)):
                noise_std = np.std(data[mask_noise, :])

            logging.info('Noise standard deviation sigma= ' + str(noise_std))

            # Find the volumes whose gradients are closest to the X, Y and Z
            # axes. b0 volumes (flagged by the gradient table or carrying an
            # all-zero vector) are pushed to infinity so they can never be
            # picked. Testing sum(bvec) == 0 would also drop genuine
            # directions such as (1, -1, 0) / sqrt(2). A copy is used so that
            # gtab itself is left untouched.
            search_bvecs = np.array(gtab.bvecs, dtype=float)
            is_b0 = gtab.b0s_mask | ~np.any(search_bvecs, axis=-1)
            search_bvecs[is_b0] = np.inf
            SNR_directions = [
                int(np.argmin(np.sum((search_bvecs - axis) ** 2, axis=-1)))
                for axis in np.eye(3)]

            SNR_b0 = mean_signal[0]/noise_std if noise_std else 0
            logging.info("SNR for the b=0 image is :" + str(SNR_b0))
            SNR_output = [SNR_b0]
            for direction in SNR_directions:
                # Compute before logging so the message reports this
                # direction's value rather than the previous one.
                SNR = mean_signal[direction]/noise_std if noise_std else 0
                logging.info("SNR for direction " + str(direction) +
                             " " + str(gtab.bvecs[direction]) + " is :" +
                             str(SNR))
                SNR_output.append(SNR)

            report = [{
                'data': ' '.join(str(value) for value in SNR_output),
                'directions': ' '.join(
                    ['b0'] + [str(direction) for direction in SNR_directions]),
            }]

            # out_path comes from the IO iterator just like cc_mask_path and
            # mask_noise_path (used as-is above). Re-joining it with out_dir
            # would duplicate a relative out_dir.
            with open(out_path, 'w') as myfile:
                json.dump(report, myfile)
180
181
def buan_bundle_profiles(model_bundle_folder, bundle_folder,
                         orig_bundle_folder, metric_folder, group_id, subject,
                         no_disks=100, out_dir=''):
    """
    Applies statistical analysis on bundles and saves the results
    in a directory specified by ``out_dir``.

    The ``*.trk`` files of the model, common-space and native-space folders
    are sorted and paired by position, one pair per model bundle. For every
    bundle with more than 5 native-space streamlines, each ``*.nii.gz``
    metric (via ``anatomical_measures``) and each ``*.pam5`` peak file (via
    ``peak_values``) in ``metric_folder`` is profiled along ``no_disks``
    disks.

    Parameters
    ----------
    model_bundle_folder : string
        Path to the input model bundle files. This path may contain
        wildcards to process multiple inputs at once.
    bundle_folder : string
        Path to the input bundle files in common space. This path may
        contain wildcards to process multiple inputs at once.
    orig_bundle_folder : string
        Path to the input bundle files in native space. This path may
        contain wildcards to process multiple inputs at once.
    metric_folder : string
        Path to the input dti metric or/and peak files. It will be used as
        metric for statistical analysis of bundles.
    group_id : integer
        what group subject belongs to either 0 for control or 1 for patient.
    subject : string
        subject id e.g. 10001.
    no_disks : integer, optional
        Number of disks used for dividing bundle into disks.
    out_dir : string, optional
        Output directory. (default current directory)

    References
    ----------
    .. [Chandio2020] Chandio, B.Q., Risacher, S.L., Pestilli, F., Bullock, D.,
    Yeh, FC., Koudoro, S., Rokem, A., Harezlak, J., and Garyfallidis, E.
    Bundle analytics, a computational framework for investigating the
    shapes and profiles of brain pathways across populations.
    Sci Rep 10, 17149 (2020)

    """
    t = time()

    mb = sorted(glob(os.path.join(model_bundle_folder, "*.trk")))
    logging.info(mb)
    bd = sorted(glob(os.path.join(bundle_folder, "*.trk")))
    logging.info(bd)
    org_bd = sorted(glob(os.path.join(orig_bundle_folder, "*.trk")))
    logging.info(org_bd)

    # The metric files do not depend on the bundle: list them only once.
    metric_files_names_dti = glob(os.path.join(metric_folder, "*.nii.gz"))
    metric_files_names_csa = glob(os.path.join(metric_folder, "*.pam5"))
    # Inverse affine of the first NIfTI metric. It is loaded lazily, only
    # once a bundle actually needs it.
    affine_r = None

    for io, model_path in enumerate(mb):

        mbundles = load_tractogram(model_path, reference='same',
                                   bbox_valid_check=False).streamlines
        bundles = load_tractogram(bd[io], reference='same',
                                  bbox_valid_check=False).streamlines
        orig_bundles = load_tractogram(org_bd[io], reference='same',
                                       bbox_valid_check=False).streamlines

        # Bundles with 5 or fewer native-space streamlines are skipped.
        if len(orig_bundles) <= 5:
            continue

        ind = np.array(assignment_map(bundles, mbundles, no_disks))

        if affine_r is None:
            _, affine = load_nifti(metric_files_names_dti[0])
            affine_r = np.linalg.inv(affine)
        # Bring native-space streamlines into the metric's voxel grid.
        transformed_orig_bundles = transform_streamlines(orig_bundles,
                                                         affine_r)

        # Bundle name: file name without its ".trk" extension.
        bm = os.path.split(model_path)[1][:-4]

        for dti_path in metric_files_names_dti:
            fm = os.path.split(dti_path)[1][:-7]  # strip ".nii.gz"

            logging.info("bm = " + bm)
            logging.info("metric = " + dti_path)

            metric, _ = load_nifti(dti_path)

            anatomical_measures(transformed_orig_bundles, metric, dict(), fm,
                                bm, subject, group_id, ind, out_dir)

        for csa_path in metric_files_names_csa:
            fm = os.path.split(csa_path)[1][:-5]  # strip ".pam5"

            logging.info("bm = " + bm)
            logging.info("metric = " + csa_path)

            metric = load_peaks(csa_path)

            peak_values(transformed_orig_bundles, metric, dict(), fm, bm,
                        subject, group_id, ind, out_dir)

    logging.info("total time taken in minutes = %s", (time() - t) / 60)
302
303
class BundleAnalysisTractometryFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return 'ba'

    def run(self, model_bundle_folder, subject_folder, no_disks=100,
            out_dir=''):
        """Workflow of bundle analytics.

        Applies statistical analysis on bundles of subjects and saves the
        results in a directory specified by ``out_dir``.

        ``subject_folder`` must contain one directory per group, named
        ``patient`` or ``control`` (case-insensitive). Each group holds one
        directory per subject with ``rec_bundles``, ``org_bundles`` and
        ``anatomical_measures`` sub-folders. Entries that are not
        directories are ignored. A ``ValueError`` is raised when
        ``subject_folder`` is not a directory or when a group directory has
        any other name.

        Parameters
        ----------

        model_bundle_folder : string
            Path to the input model bundle files. This path may
            contain wildcards to process multiple inputs at once.

        subject_folder : string
            Path to the input subject folder. This path may contain
            wildcards to process multiple inputs at once.

        no_disks : integer, optional
            Number of disks used for dividing bundle into disks.

        out_dir : string, optional
            Output directory. (default current directory)

        References
        ----------
        .. [Chandio2020] Chandio, B.Q., Risacher, S.L., Pestilli, F.,
        Bullock, D., Yeh, FC., Koudoro, S., Rokem, A., Harezlak, J., and
        Garyfallidis, E. Bundle analytics, a computational framework for
        investigating the shapes and profiles of brain pathways across
        populations. Sci Rep 10, 17149 (2020)

        """

        if not os.path.isdir(subject_folder):
            raise ValueError("Invalid path to subjects")

        group_ids = {'patient': 1,  # 1 means patient
                     'control': 0}  # 0 means control

        for group in sorted(os.listdir(subject_folder)):
            group_path = os.path.join(subject_folder, group)
            # Stray files at the group level are skipped. Previously they
            # fell through and reused the previous group's subject list (or
            # hit an undefined name on the first entry).
            if not os.path.isdir(group_path):
                continue

            logging.info('group = {0}'.format(group))
            all_subjects = sorted(os.listdir(group_path))
            logging.info(all_subjects)

            group_id = group_ids.get(group.lower())
            if group_id is None:
                raise ValueError("Invalid group. Neither patient nor "
                                 "control: {0}".format(group))

            for sub in all_subjects:
                logging.info(sub)
                pre = os.path.join(group_path, sub)
                logging.info(pre)
                # Only subject directories hold bundles, so other entries
                # are skipped (same rule as BundleShapeAnalysis).
                if not os.path.isdir(pre):
                    continue
                b = os.path.join(pre, "rec_bundles")
                c = os.path.join(pre, "org_bundles")
                d = os.path.join(pre, "anatomical_measures")
                buan_bundle_profiles(model_bundle_folder, b, c, d, group_id,
                                     sub, no_disks, out_dir)
371
372
class LinearMixedModelsFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return 'lmm'

    def get_metric_name(self, path):
        """ Splits the path string and returns name of anatomical measure
        (eg: fa), bundle name eg(AF_L) and bundle name with metric name
        (eg: AF_L_fa)

        The file name is split at the first ``_`` that is not followed by
        ``L``, ``R`` or ``F``. The combined name is everything before the
        last ``.`` of the file name.

        Parameters
        ----------
        path : string
            Path to the input metric files. This path may
            contain wildcards to process multiple inputs at once.

        Returns
        -------
        metric_name, bundle_name, bundle_metric_name : string
            e.g. ``('fa', 'AF_L', 'AF_L_fa')`` for ``.../AF_L_fa.h5``.
            Three single-space strings are returned when no split point
            is found.
        """

        name = os.path.basename(path)
        # Index of the last '.', or 0 when there is none (or only a leading
        # one).
        count = max(name.rfind('.'), 0)

        # Stop one short of the end so name[j + 1] always exists; a trailing
        # '_' used to raise IndexError.
        for j in range(len(name) - 1):
            if name[j] == '_' and name[j + 1] not in ('L', 'R', 'F'):
                return name[j + 1:count], name[:j], name[:count]

        return " ", " ", " "

    def save_lmm_plot(self, plot_file, title, bundle_name, x, y):
        """ Saves LMM plot with segment/disk number on x-axis and
        -log10(pvalues) on y-axis in out_dir folder.

        Parameters
        ----------
        plot_file : string
            Path to the plot file. This path may
            contain wildcards to process multiple inputs at once.
        title : string
            Title for the plot.
        bundle_name : string
            Legend label for the bars.
        x : list
            list containing segment/disk number for x-axis.
        y : list
            list containing -log10(pvalues) per segment/disk number for y-axis.

        """

        # Bars and reference lines are drawn at the segment numbers given in
        # x, as documented, rather than at 0-based positions.
        x_pos = np.asarray(x)
        # Horizontal reference lines at -log10(p) = 2 and 3, i.e. p = 0.01
        # and p = 0.001.
        reference = np.full(len(x_pos), 2.0)
        c1 = np.random.rand(1, 3)  # random bar colour

        l1, = plt.plot(x_pos, reference, color='red', marker='.',
                       linestyle='solid', linewidth=0.6,
                       markersize=0.7, label="p-value < 0.01")

        l2, = plt.plot(x_pos, reference + 1, color='black', marker='.',
                       linestyle='solid', linewidth=0.4,
                       markersize=0.4, label="p-value < 0.001")

        first_legend = plt.legend(handles=[l1, l2],
                                  loc='upper right')

        # Keep the threshold legend when the bar legend is added below.
        axes = plt.gca()
        axes.add_artist(first_legend)
        axes.set_ylim([0, 6])

        l3 = plt.bar(x_pos, y, color=c1, alpha=0.5, label=bundle_name)
        plt.legend(handles=[l3], loc='upper left')
        plt.title(title.upper())
        plt.xlabel("Segment Number")
        plt.ylabel("-log10(Pvalues)")
        plt.savefig(plot_file)
        plt.clf()

    def run(self, h5_files, no_disks=100, out_dir=''):
        """Workflow of linear Mixed Models.

        Applies linear Mixed Models on bundles of subjects and saves the
        results in a directory specified by ``out_dir``.

        For each input file and each disk, the model ``<metric> ~ group``
        is fitted with subjects as random-effect groups. The p-values, their
        -log10 and a bar plot are saved as ``<bundle>_<metric>_pvalues.npy``,
        ``<bundle>_<metric>_pvalues_log.npy`` and ``<bundle>_<metric>.png``.
        A ``ValueError`` is raised when a disk has fewer than 10 rows.

        Parameters
        ----------

        h5_files : string
            Path to the input metric files. This path may
            contain wildcards to process multiple inputs at once.

        no_disks : integer, optional
            Number of disks used for dividing bundle into disks.

        out_dir : string, optional
            Output directory. (default current directory)

        """

        io_it = self.get_io_iterator()

        for file_path in io_it:

            logging.info('Applying metric {0}'.format(file_path))

            file_name, bundle_name, save_name = self.get_metric_name(file_path)
            logging.info(" file name = " + file_name)
            logging.info("file path = " + file_path)

            pvalues = np.zeros(no_disks)
            # Silence model-fit warnings only while fitting, instead of
            # disabling warnings process-wide for the rest of the session.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # run mixed linear model for every disk
                for i in range(no_disks):
                    disk_count = i + 1
                    # Put the value in the query explicitly rather than
                    # relying on pandas looking up a local variable name.
                    df = pd.read_hdf(file_path,
                                     where='disk={0}'.format(disk_count))

                    logging.info("read the dataframe for disk number " +
                                 str(disk_count))
                    # check if data has significant data to perform LMM
                    if len(df) < 10:
                        raise ValueError(
                            "Dataset for Linear Mixed Model is too small")

                    criteria = file_name + " ~ group"
                    md = smf.mixedlm(criteria, df,
                                     groups=df["subject"])

                    mdf = md.fit()

                    # Positional access: entry 1 is presumably the ``group``
                    # coefficient (after the intercept). np.asarray avoids
                    # pandas' deprecated integer-key lookup on a
                    # label-indexed Series.
                    pvalues[i] = np.asarray(mdf.pvalues)[1]

            x = list(range(1, len(pvalues) + 1))
            y = -1 * np.log10(pvalues)

            save_file = os.path.join(out_dir, save_name + "_pvalues.npy")
            np.save(save_file, pvalues)

            save_file = os.path.join(out_dir, save_name + "_pvalues_log.npy")
            np.save(save_file, y)

            save_file = os.path.join(out_dir, save_name + ".png")
            self.save_lmm_plot(save_file, file_name, bundle_name, x, y)
521
522
class BundleShapeAnalysis(Workflow):
    @classmethod
    def get_short_name(cls):
        return 'BS'

    def run(self, subject_folder, clust_thr=[5, 3, 1.5], threshold=6,
            out_dir=''):
        """Workflow of bundle analytics.

        Applies bundle shape similarity analysis on bundles of subjects and
        saves the results in a directory specified by ``out_dir``.

        Subjects are gathered from every group directory of
        ``subject_folder``, in sorted order. For every bundle file found in
        the first subject's ``rec_bundles`` folder, an N x N similarity
        matrix is saved as ``<bundle>.npy`` together with a
        ``SM_<bundle>`` image. A ``ValueError`` is raised when
        ``subject_folder`` is not a directory or contains no subject
        directories.

        Parameters
        ----------

        subject_folder : string
            Path to the input subject folder. This path may contain
            wildcards to process multiple inputs at once.

        clust_thr : variable float, optional
            list of bundle clustering thresholds used in QuickBundlesX.

        threshold : float, optional
            Bundle shape similarity threshold.

        out_dir : string, optional
            Output directory. (default current directory)

        References
        ----------
        .. [Chandio2020] Chandio, B.Q., Risacher, S.L., Pestilli, F.,
        Bullock, D., Yeh, FC., Koudoro, S., Rokem, A., Harezlak, J., and
        Garyfallidis, E. Bundle analytics, a computational framework for
        investigating the shapes and profiles of brain pathways across
        populations. Sci Rep 10, 17149 (2020)

        """
        rng = np.random.RandomState()
        if not os.path.isdir(subject_folder):
            raise ValueError("Not a directory")

        all_subjects = []
        for group in sorted(os.listdir(subject_folder)):
            group_path = os.path.join(subject_folder, group)
            if not os.path.isdir(group_path):
                continue
            subjects = sorted(os.listdir(group_path))
            logging.info("first " + str(len(subjects)) +
                         " subjects in matrix belong to " + group +
                         " group")
            for sub in subjects:
                dpath = os.path.join(group_path, sub)
                if os.path.isdir(dpath):
                    all_subjects.append(dpath)

        if not all_subjects:
            # Previously surfaced as an opaque IndexError on all_subjects[0].
            raise ValueError("No subject directories found in " +
                             subject_folder)

        N = len(all_subjects)

        bundles = os.listdir(os.path.join(all_subjects[0], "rec_bundles"))
        for bun in bundles:
            logging.info(bun)

            # Read each subject's copy of this bundle once, instead of
            # reloading it for every one of the N * N pairs.
            subject_bundles = []
            for sub in all_subjects:
                logging.info(sub)
                subject_bundles.append(
                    load_tractogram(os.path.join(sub, "rec_bundles", bun),
                                    reference='same',
                                    bbox_valid_check=False).streamlines)

            # bundle shape similarity matrix; pairs are evaluated in the same
            # row-major order as before, so rng is consumed identically.
            ba_matrix = np.zeros((N, N))
            for i, bundle1 in enumerate(subject_bundles):
                for j, bundle2 in enumerate(subject_bundles):
                    ba_matrix[i, j] = bundle_shape_similarity(
                        bundle1, bundle2, rng, clust_thr, threshold)

            logging.info("saving BA score matrix")
            np.save(os.path.join(out_dir, bun[:-4] + ".npy"), ba_matrix)

            # A colormap name works with every matplotlib version, whereas
            # matplotlib.cm.get_cmap was removed in matplotlib 3.9.
            plt.title(bun[:-4])
            plt.imshow(ba_matrix, cmap='Blues')
            plt.colorbar()
            plt.clim(0, 1)
            plt.savefig(os.path.join(out_dir, "SM_" + bun[:-4]))
            plt.clf()
622