1import numpy as np
2from scipy.ndimage import map_coordinates
3from scipy.fftpack import fftn, fftshift, ifftshift
4from dipy.reconst.odf import OdfModel, OdfFit
5from dipy.reconst.cache import Cache
6from dipy.reconst.multi_voxel import multi_voxel_fit
7
8
class DiffusionSpectrumModel(OdfModel, Cache):

    def __init__(self,
                 gtab,
                 qgrid_size=17,
                 r_start=2.1,
                 r_end=6.,
                 r_step=0.2,
                 filter_width=32,
                 normalize_peaks=False):
        r""" Diffusion Spectrum Imaging

        The theoretical idea underlying this method is that the diffusion
        propagator $P(\mathbf{r})$ (probability density function of the average
        spin displacements) can be estimated by applying 3D FFT to the signal
        values $S(\mathbf{q})$

        .. math::
            :nowrap:

            \begin{eqnarray}
                P(\mathbf{r}) & = & S_{0}^{-1}\int S(\mathbf{q})\exp(-i2\pi\mathbf{q}\cdot\mathbf{r})d\mathbf{q}
            \end{eqnarray}

        where $\mathbf{r}$ is the displacement vector and $\mathbf{q}$ is the
        wave vector which corresponds to different gradient directions. The
        ODFs are then computed from the propagator with the method proposed
        by Wedeen et al. [1]_.

        The main assumption for this model is fast gradient switching and that
        the acquisition gradients will sit on a keyhole Cartesian grid in
        q_space [3]_.

        Parameters
        ----------
        gtab : GradientTable,
            Gradient directions and bvalues container class
        qgrid_size : int,
            has to be an odd number. Sets the size of the q_space grid.
            For example if qgrid_size is 17 then the shape of the grid will be
            ``(17, 17, 17)``.
        r_start : float,
            ODF is sampled radially in the PDF. This parameter sets where the
            sampling should start.
        r_end : float,
            Radial endpoint of ODF sampling (exclusive).
        r_step : float,
            Step size of the ODF sampling from r_start to r_end
        filter_width : float,
            Width of the Hanning filter applied to the signal before the FFT;
            larger values attenuate high q-values less.
        normalize_peaks : bool, optional
            Stored on the model as ``normalize_peaks``.

        Raises
        ------
        ValueError
            If ``qgrid_size`` is even.

        References
        ----------
        .. [1]  Wedeen V.J et al., "Mapping Complex Tissue Architecture With
        Diffusion Spectrum Magnetic Resonance Imaging", MRM 2005.

        .. [2] Canales-Rodriguez E.J et al., "Deconvolution in Diffusion
        Spectrum Imaging", Neuroimage, 2010.

        .. [3] Garyfallidis E, "Towards an accurate brain tractography", PhD
        thesis, University of Cambridge, 2012.

        Examples
        --------
        In this example where we provide the data, a gradient table
        and a reconstruction sphere, we calculate generalized FA for the first
        voxel in the data with the reconstruction performed using DSI.

        >>> import warnings
        >>> from dipy.data import dsi_voxels, default_sphere
        >>> data, gtab = dsi_voxels()
        >>> from dipy.reconst.dsi import DiffusionSpectrumModel
        >>> ds = DiffusionSpectrumModel(gtab)
        >>> dsfit = ds.fit(data)
        >>> from dipy.reconst.odf import gfa
        >>> np.round(gfa(dsfit.odf(default_sphere))[0, 0, 0], 2)
        0.11

        Notes
        -----
        A. Keep in mind that DSI expects gradients on both hemispheres. If your
        gradients span only one hemisphere you need to duplicate the data and
        project them to the other hemisphere before calling this class. The
        function dipy.reconst.dsi.half_to_full_qspace can be used for this
        purpose.

        B. If you increase the size of the grid (parameter qgrid_size) you will
        most likely also need to update the r_* parameters. This is because
        the added zero padding from the increase of qgrid_size also introduces
        a scaling of the PDF.

        C. We assume that only one b0 volume is provided with the data.

        See Also
        --------
        dipy.reconst.gqi.GeneralizedQSampling

        """
        # Validate before touching any state: the grid needs a central node
        # for q = 0, which only exists for an odd size.
        if qgrid_size % 2 == 0:
            raise ValueError('qgrid_size needs to be an odd integer')

        self.bvals = gtab.bvals
        self.bvecs = gtab.bvecs
        self.normalize_peaks = normalize_peaks
        # 3d volume for Sq
        self.qgrid_size = qgrid_size
        # necessary shifting for centering: index of q = 0 along each axis
        self.origin = self.qgrid_size // 2

        # hanning filter weights, one per gradient
        self.filter = hanning_filter(gtab, filter_width, self.origin)
        # odf sampling radius
        self.qradius = np.arange(r_start, r_end, r_step)
        self.qradiusn = len(self.qradius)
        # create qspace grid (integer node index of every gradient)
        self.qgrid = create_qspace(gtab, self.origin)
        # number of volumes acquired above the minimum b-value
        b0 = np.min(self.bvals)
        self.dn = (self.bvals > b0).sum()
        self.gtab = gtab

    @multi_voxel_fit
    def fit(self, data):
        """ Fit the DSI model; returns a ``DiffusionSpectrumFit``.

        ``multi_voxel_fit`` presumably dispatches this per voxel for
        multi-dimensional ``data`` -- see dipy.reconst.multi_voxel.
        """
        return DiffusionSpectrumFit(self, data)
131
132
class DiffusionSpectrumFit(OdfFit):

    def __init__(self, model, data):
        """ Calculates PDF and ODF and other properties for a single voxel

        Parameters
        ----------
        model : object,
            DiffusionSpectrumModel
        data : 1d ndarray,
            signal values, one per gradient of the model's gradient table
        """
        self.model = model
        self.data = data
        self.qgrid_sz = self.model.qgrid_size
        self.dn = self.model.dn
        self._gfa = None
        self.npeaks = 5
        self._peak_values = None
        self._peak_indices = None

    def pdf(self, normalized=True):
        """ Applies the 3D FFT in the q-space grid to generate
        the diffusion propagator

        Parameters
        ----------
        normalized : boolean, optional
            Whether to divide the propagator by its sum in order to obtain a
            pdf. Default: True.

        Returns
        -------
        Pr : ndarray, shape (qgrid_size, qgrid_size, qgrid_size)
            The propagator, with negative values (ringing) clipped to zero.
        """
        values = self.data * self.model.filter
        # create the signal volume
        Sq = np.zeros((self.qgrid_sz,) * 3)
        # fill q-space; signals whose q-vectors round to the same grid node
        # are accumulated
        for i, value in enumerate(values):
            qx, qy, qz = self.model.qgrid[i]
            Sq[qx, qy, qz] += value
        # apply fourier transform; ifftshift moves the grid centre (q = 0) to
        # index 0 and fftshift moves r = 0 back to the centre
        Pr = fftshift(np.real(fftn(ifftshift(Sq), 3 * (self.qgrid_sz, ))))
        # clipping negative values to 0 (ringing artefact)
        Pr = np.clip(Pr, 0, Pr.max())

        # normalize the propagator to obtain a pdf
        if normalized:
            Pr /= Pr.sum()

        return Pr

    def rtop_signal(self, filtering=True):
        """ Calculates the return to origin probability (rtop) from the signal
        rtop equals the sum of all signal values

        Parameters
        ----------
        filtering : boolean, optional
            Whether to perform Hanning filtering. Default: True

        Returns
        -------
        rtop : float
            the return to origin probability
        """
        values = self.data * self.model.filter if filtering else self.data
        return values.sum()

    def rtop_pdf(self, normalized=True):
        r""" Calculates the return to origin probability from the propagator,
        which is the propagator evaluated at zero (see Descoteaux et Al. [1]_,
        Tuch [2]_, Wu et al. [3]_)
        rtop = P(0)

        Parameters
        ----------
        normalized : boolean, optional
            Whether to normalize the propagator by its sum in order to obtain a
            pdf. Default: True.

        Returns
        -------
        rtop : float
            the return to origin probability

        References
        ----------
        .. [1] Descoteaux M. et al., "Multiple q-shell diffusion propagator
        imaging", Medical Image Analysis, vol 15, No. 4, p. 603-621, 2011.

        .. [2] Tuch D.S., "Diffusion MRI of Complex Tissue Structure",
         PhD Thesis, 2002.

        .. [3] Wu Y. et al., "Computation of Diffusion Function Measures
        in q -Space Using Magnetic Resonance Hybrid Diffusion Imaging",
        IEEE TRANSACTIONS ON MEDICAL IMAGING, vol. 27, No. 6, p. 858-865, 2008

        """
        Pr = self.pdf(normalized=normalized)
        # r = 0 sits at the centre of the (odd-sized) grid
        center = self.qgrid_sz // 2
        return Pr[center, center, center]

    def msd_discrete(self, normalized=True):
        r""" Calculates the mean squared displacement on the discrete propagator

        .. math::
            :nowrap:

            \begin{equation}
                MSD_{DSI}=\int_{-\infty}^{\infty}\int_{-\infty}^{\infty}\int_{-\infty}^{\infty} P(\hat{\mathbf{r}}) \cdot \hat{\mathbf{r}}^{2} \ dr_x \ dr_y \ dr_z
            \end{equation}

        where $\hat{\mathbf{r}}$ is a point in the 3D Propagator space
        (see Wu et al. [1]_). In this discrete implementation the weighted
        sum is divided by the number of grid nodes, ``qgrid_size ** 3``.

        Parameters
        ----------
        normalized : boolean, optional
            Whether to normalize the propagator by its sum in order to obtain a
            pdf. Default: True

        Returns
        -------
        msd : float
            the mean square displacement

        References
        ----------
        .. [1] Wu Y. et al., "Hybrid diffusion imaging", NeuroImage, vol 36,
        p. 617-629, 2007.

        """
        Pr = self.pdf(normalized=normalized)

        # squared distance (in grid steps) of every node from the centre
        gridsize = self.qgrid_sz
        center = gridsize // 2
        a = np.arange(gridsize) - center
        r2 = (a[:, None, None] ** 2 + a[None, :, None] ** 2 +
              a[None, None, :] ** 2)

        return np.sum(Pr * r2) / float(gridsize ** 3)

    def odf(self, sphere):
        r""" Calculates the real discrete odf for a given discrete sphere

        .. math::
            :nowrap:

            \begin{equation}
                \psi_{DSI}(\hat{\mathbf{u}})=\int_{0}^{\infty}P(r\hat{\mathbf{u}})r^{2}dr
            \end{equation}

        where $\hat{\mathbf{u}}$ is the unit vector which corresponds to a
        sphere point.

        Parameters
        ----------
        sphere : Sphere
            Its ``vertices`` give the directions $\hat{\mathbf{u}}$. The
            interpolation coordinates are cached on the model per sphere.

        Returns
        -------
        odf : ndarray, shape (M,)
            ODF value for each of the M sphere vertices, integrated over the
            model's ``qradius`` samples of the normalized propagator.
        """
        interp_coords = self.model.cache_get('interp_coords',
                                             key=sphere)
        if interp_coords is None:
            interp_coords = pdf_interp_coords(sphere,
                                              self.model.qradius,
                                              self.model.origin)
            self.model.cache_set('interp_coords', sphere, interp_coords)

        Pr = self.pdf()

        # calculate the orientation distribution function
        return pdf_odf(Pr, self.model.qradius, interp_coords)
307
308
def create_qspace(gtab, origin):
    """ create the 3D grid which holds the signal values (q-space)

    Every gradient's rounded q-space coordinate (from ``create_qtable``) is
    shifted by ``origin`` so that q = 0 lands on grid index ``origin``.

    Parameters
    ----------
    gtab : GradientTable
    origin : int
        center of qspace along each axis (``qgrid_size // 2`` in this module)

    Returns
    -------
    qgrid : ndarray of int64, shape (N, 3)
        qspace grid indices, one row per gradient
    """
    centred = create_qtable(gtab, origin)
    shifted = centred + origin
    return shifted.astype('i8')
329
330
def create_qtable(gtab, origin):
    """ create a normalized version of gradients

    Each gradient vector is scaled by ``sqrt(b / bmin)`` and rounded to the
    nearest integer, giving its coordinate on a Cartesian q-space grid whose
    unit step corresponds to the b-value ``bmin``. ``bmin`` is the smallest
    non-b0 b-value for which the largest b-value still lies within
    ``origin + 1`` grid steps; if none qualifies the largest non-b0 b-value
    is used.

    Parameters
    ----------
    gtab : GradientTable
        Must provide ``bvals``, ``bvecs`` and ``b0s_mask``.
    origin : int
        center of qspace (half the grid size); must be a scalar.

    Returns
    -------
    qtable : ndarray, shape (N, 3)
        Rounded (float-valued) grid coordinates of all N gradients, relative
        to the grid center.

    Raises
    ------
    ValueError
        If ``gtab`` contains no non-b0 volume, so no grid step can be chosen.
    """
    bv = gtab.bvals
    bsorted = np.sort(bv[np.bitwise_not(gtab.b0s_mask)])
    if bsorted.size == 0:
        raise ValueError('create_qtable needs at least one non-b0 volume')

    # Radius (in grid steps) of the outermost shell if b were the unit step.
    # NumPy division by zero yields inf (it never raises), so silence it.
    with np.errstate(divide='ignore', invalid='ignore'):
        radii = np.sqrt(bv.max() / bsorted)
    # First b-value whose outermost shell fits; written as "not greater"
    # so that NaN radii are accepted.
    fits = np.flatnonzero(~(radii > origin + 1))
    bmin = bsorted[fits[0]] if fits.size else bsorted[-1]

    bv = np.sqrt(bv / bmin)
    qtable = bv[:, np.newaxis] * gtab.bvecs
    # round half up to the nearest grid node
    return np.floor(qtable + .5)
360
361
def hanning_filter(gtab, filter_width, origin):
    """ create a hanning window

    The signal is premultiplied by a Hanning window before
    Fourier transform in order to ensure a smooth attenuation
    of the signal at high q values.

    The weight of each gradient is ``0.5 * cos(2 * pi * r / filter_width)``,
    where ``r`` is the norm of its rounded q-space coordinate from
    ``create_qtable``. Note that this is a scaled cosine rather than the
    textbook Hann window ``0.5 * (1 + cos(...))``: it starts at 0.5 for
    ``r = 0``, reaches 0 at ``r = filter_width / 4`` and is negative beyond
    that. With ``filter_width=np.inf`` every weight is 0.5.

    Parameters
    ----------
    gtab : GradientTable
    filter_width : float
        Width of the window; larger values attenuate high q-values less.
    origin : int
        center of qspace, forwarded to ``create_qtable``

    Returns
    -------
    filter : (N,) ndarray
        where N is the total number of gradients in ``gtab`` (b0 volumes
        included), so it can be multiplied element-wise with a voxel's signal.

    """
    qtable = create_qtable(gtab, origin)
    # calculate r - hanning filter free parameter (distance from q = 0)
    r = np.sqrt(np.sum(qtable ** 2, axis=1))
    # setting hanning filter width and hanning
    return .5 * np.cos(2 * np.pi * r / filter_width)
387
388
def pdf_interp_coords(sphere, rradius, origin):
    """ Precompute coordinates for ODF calculation from the PDF

    For every sphere vertex, points are placed along the ray from ``origin``
    in that vertex's direction, at each distance in ``rradius``.

    Parameters
    ----------
    sphere : object,
            Sphere (only its ``vertices`` array, shape (M, 3), is used)
    rradius : array, shape (N,)
            line interpolation points
    origin : scalar or array, shape (3,)
            center of the grid

    Returns
    -------
    interp_coords : ndarray, shape (3, M, N)
            grid coordinates, suitable for ``scipy.ndimage.map_coordinates``
    """
    # (M, 3) -> (3, M, 1) so that scaling by the (N,) radii broadcasts
    # to (3, M, N): one sampled ray per axis and vertex.
    directions = sphere.vertices.T[:, :, np.newaxis]
    scaled = rradius * directions
    shift = np.reshape(origin, (-1, 1, 1))
    return shift + scaled
406
407
def pdf_odf(Pr, rradius, interp_coords):
    r""" Calculates the real ODF from the diffusion propagator(PDF) Pr

    Parameters
    ----------
    Pr : array, shape (X, X, X)
        probability density function
    rradius : array, shape (N,)
        interpolation range on the radius
    interp_coords : array, shape (3, M, N)
        coordinates in the pdf for interpolating the odf

    Returns
    -------
    odf : ndarray, shape (M,)
        sum over the N radial samples of the linearly interpolated
        propagator weighted by ``rradius ** 2``
    """
    # sample the propagator along every ray with linear interpolation
    samples = map_coordinates(Pr, interp_coords, order=1)
    # r**2 weighting, then integrate along the radial (last) axis
    weights = rradius ** 2
    return np.sum(samples * weights, axis=-1)
423
424
def half_to_full_qspace(data, gtab):
    """ Half to full Cartesian grid mapping

    Useful when dMRI data are provided in one qspace hemisphere as
    DiffusionSpectrum expects data to be in full qspace. Every volume after
    the first is duplicated, and the duplicate's gradient direction is
    negated.

    Parameters
    ----------
    data : array, shape (X, Y, Z, W)
        where (X, Y, Z) volume size and W number of gradient directions
    gtab : GradientTable
        container for b-values and b-vectors (gradient directions)

    Returns
    -------
    new_data : array, shape (X, Y, Z, 2 * W - 1)
    new_gtab : GradientTable
        The same ``gtab`` object that was passed in; its ``bvals`` and
        ``bvecs`` attributes are replaced in place.

    Notes
    -----
    We assume here that only one b0 is provided with the initial data, as
    its first volume. If that is not the case then you will need to write
    your own preparation function before providing the gradients and the
    data to the DiffusionSpectrumModel class.
    """
    full_bvals = np.concatenate((gtab.bvals, gtab.bvals[1:]))
    full_bvecs = np.concatenate((gtab.bvecs, -gtab.bvecs[1:]), axis=0)
    full_data = np.concatenate((data, data[..., 1:]), axis=-1)
    # NOTE: mutates the caller's gradient table
    gtab.bvals = full_bvals
    gtab.bvecs = full_bvecs
    return full_data, gtab
458
459
def project_hemisph_bvecs(gtab):
    """ Project any near identical bvecs to the other hemisphere

    Each gradient after the first is paired with its nearest neighbour, using
    the Euclidean distance between b-value-weighted gradient vectors. A pair
    (i, j) is recorded unless (j, i) was already recorded, and the second
    member j of every recorded pair is negated.

    Parameters
    ----------
    gtab : object,
            GradientTable

    Returns
    -------
    bvecs2 : ndarray
        Copy of ``gtab.bvecs`` with the projected vectors negated. A vector
        that is the second member of several pairs is negated once per pair.
    pairs : list of tuple
        The recorded (i, j) pairs, indexed relative to ``gtab.bvecs[1:]``.

    Notes
    -------
    Useful only when working with some types of dsi data. The first gradient
    (presumably the single b0) is skipped and never modified.
    """
    weighted = gtab.bvals[1:, None] * gtab.bvecs[1:]

    pairs = []
    for i, vec in enumerate(weighted):
        row = np.array([np.sqrt(np.sum((vec - other) ** 2))
                        for other in weighted])
        order = np.argsort(row)
        others = [k for k in order if k != i]
        # a lone vector has no other neighbour and is paired with itself
        nearest = others[0] if others else order[-1]
        if (nearest, i) not in pairs:
            pairs.append((i, nearest))

    bvecs2 = gtab.bvecs.copy()
    for _, j in pairs:
        bvecs2[1 + j] = -bvecs2[1 + j]
    return bvecs2, pairs
494
495
class DiffusionSpectrumDeconvModel(DiffusionSpectrumModel):

    def __init__(self, gtab, qgrid_size=35, r_start=4.1, r_end=13.,
                 r_step=0.4, filter_width=np.inf, normalize_peaks=False):
        r""" Diffusion Spectrum Deconvolution

        The idea is to remove the convolution on the DSI propagator that is
        caused by the truncation of the q-space in the DSI sampling.

        .. math::
            :nowrap:

            \begin{eqnarray*}
                P_{dsi}(\mathbf{r}) & = & S_{0}^{-1}\iiint\limits_{\| \mathbf{q} \| \le \mathbf{q_{max}}} S(\mathbf{q})\exp(-i2\pi\mathbf{q}\cdot\mathbf{r})d\mathbf{q} \\
                & = & S_{0}^{-1}\iiint\limits_{\mathbf{q}} \left( S(\mathbf{q}) \cdot M(\mathbf{q}) \right) \exp(-i2\pi\mathbf{q}\cdot\mathbf{r})d\mathbf{q} \\
                & = & P(\mathbf{r}) \otimes \left( S_{0}^{-1}\iiint\limits_{\mathbf{q}}  M(\mathbf{q}) \exp(-i2\pi\mathbf{q}\cdot\mathbf{r})d\mathbf{q} \right) \\
            \end{eqnarray*}

        where $\mathbf{r}$ is the displacement vector and $\mathbf{q}$ is the
        wave vector which corresponds to different gradient directions,
        $M(\mathbf{q})$ is a mask corresponding to your q-space sampling and
        $\otimes$ is the convolution operator [1]_.

        Parameters
        ----------
        gtab : GradientTable,
            Gradient directions and bvalues container class
        qgrid_size : int,
            has to be an odd number. Sets the size of the q_space grid.
            For example if qgrid_size is 35 then the shape of the grid will be
            ``(35, 35, 35)``.
        r_start : float,
            ODF is sampled radially in the PDF. This parameter sets where the
            sampling should start.
        r_end : float,
            Radial endpoint of ODF sampling (exclusive).
        r_step : float,
            Step size of the ODF sampling from r_start to r_end
        filter_width : float,
            Width of the Hanning filter computed by the parent model. Note
            that ``DiffusionSpectrumDeconvFit.pdf`` does not apply this
            filter to the signal.
        normalize_peaks : bool, optional
            Stored on the model as ``normalize_peaks``.

        References
        ----------
        .. [1] Canales-Rodriguez E.J et al., "Deconvolution in Diffusion
        Spectrum Imaging", Neuroimage, 2010.

        .. [2] Biggs David S.C. et al., "Acceleration of Iterative Image
        Restoration Algorithms", Applied Optics, vol. 36, No. 8, p. 1766-1775,
        1997.

        """
        super().__init__(gtab, qgrid_size=qgrid_size, r_start=r_start,
                         r_end=r_end, r_step=r_step,
                         filter_width=filter_width,
                         normalize_peaks=normalize_peaks)

    @multi_voxel_fit
    def fit(self, data):
        """ Fit the deconvolution DSI model; returns a
        ``DiffusionSpectrumDeconvFit``.

        ``multi_voxel_fit`` presumably dispatches this per voxel for
        multi-dimensional ``data`` -- see dipy.reconst.multi_voxel.
        """
        return DiffusionSpectrumDeconvFit(self, data)
555
556
class DiffusionSpectrumDeconvFit(DiffusionSpectrumFit):

    def pdf(self, normalized=True):
        """ Applies the 3D FFT in the q-space grid to generate
        the DSI diffusion propagator, remove the background noise with a
        hard threshold and then deconvolve the propagator with the
        Lucy-Richardson deconvolution algorithm

        Parameters
        ----------
        normalized : boolean, optional
            Accepted for compatibility with ``DiffusionSpectrumFit.pdf``:
            the inherited ``rtop_pdf`` and ``msd_discrete`` pass it as a
            keyword argument. The deconvolution step already divides the
            propagator by its sum, so this flag has no effect.
            Default: True.

        Returns
        -------
        Pr : ndarray, shape (qgrid_size, qgrid_size, qgrid_size)
            The deconvolved propagator.
        """
        values = self.data
        # create the signal volume
        Sq = np.zeros((self.qgrid_sz,) * 3)
        # fill q-space (unlike the parent class, no Hanning filter is applied)
        for i, value in enumerate(values):
            qx, qy, qz = self.model.qgrid[i]
            Sq[qx, qy, qz] += value
        # get deconvolution PSF, computing and caching it on first use only
        DSID_PSF = self.model.cache_get('deconv_psf', key=self.model.gtab)
        if DSID_PSF is None:
            DSID_PSF = gen_PSF(self.model.qgrid, self.qgrid_sz,
                               self.qgrid_sz, self.qgrid_sz)
            self.model.cache_set('deconv_psf', self.model.gtab, DSID_PSF)
        # apply fourier transform
        Pr = fftshift(np.abs(np.real(fftn(ifftshift(Sq),
                                          3 * (self.qgrid_sz, )))))
        # threshold propagator
        Pr = threshold_propagator(Pr)
        # apply LR deconvolution (5 iterations, acceleration factor 2)
        Pr = LR_deconv(Pr, DSID_PSF, 5, 2)
        return Pr
586
587
def threshold_propagator(P, estimated_snr=15.):
    """
    Applies hard threshold on the propagator to remove background noise for the
    deconvolution.

    Values below ``P.max() / estimated_snr`` are set to zero and the result
    is normalized to unit sum.

    Parameters
    ----------
    P : ndarray
        Propagator; it is not modified.
    estimated_snr : float, optional
        The threshold is the maximum of ``P`` divided by this value.
        Default: 15.

    Returns
    -------
    P_thresholded : ndarray
        Thresholded copy of ``P`` divided by its sum. If that sum is zero
        (e.g. ``P`` is all zeros) the thresholded copy is returned
        unnormalized instead of producing NaNs from 0 / 0.
    """
    P_thresholded = P.copy()
    threshold = P_thresholded.max() / float(estimated_snr)
    P_thresholded[P_thresholded < threshold] = 0
    total = P_thresholded.sum()
    if total == 0:
        return P_thresholded
    return P_thresholded / total
597
598
def gen_PSF(qgrid_sampling, siz_x, siz_y, siz_z):
    """
    Generate a PSF for DSI Deconvolution by taking the ifft of the binary
    q-space sampling mask and truncating it to keep only the center.

    Parameters
    ----------
    qgrid_sampling : ndarray of int, shape (N, 3)
        grid indices of the sampled q-space nodes
    siz_x, siz_y, siz_z : int
        shape of the q-space grid

    Returns
    -------
    psf : ndarray, shape (siz_x, siz_y, siz_z)
        real part of the centred inverse FFT of the sampling mask, zeroed
        everywhere outside the mask
    """
    mask = np.zeros((siz_x, siz_y, siz_z))
    # mark every sampled q-space node
    for qx, qy, qz in qgrid_sampling:
        mask[qx, qy, qz] = 1
    centred = np.fft.ifftshift(mask)
    kernel = np.real(np.fft.fftshift(np.fft.ifftn(centred)))
    return mask * kernel
610
611
def LR_deconv(prop, psf, numit=5, acc_factor=1):
    r"""
    Perform Lucy-Richardson deconvolution algorithm on a 3D array.

    Parameters
    ----------
    prop : 3-D ndarray of dtype float
        The 3D volume to be deconvolved
    psf : 3-D ndarray of dtype float
        The filter that will be used for the deconvolution. It is centred in
        an array shaped like ``prop``, so none of its dimensions should
        exceed the matching dimension of ``prop``.
    numit : int
        Number of Lucy-Richardson iterations to perform.
    acc_factor : float
        Exponential acceleration factor as in [1]_.

    Returns
    -------
    prop_deconv : ndarray, same shape as ``prop``
        Non-negative deconvolved volume divided by its sum. If the estimate
        sums to zero (e.g. ``prop`` is all zeros) it is returned
        unnormalized instead of producing NaNs from 0 / 0.

    References
    ----------
    .. [1] Biggs David S.C. et al., "Acceleration of Iterative Image
       Restoration Algorithms", Applied Optics, vol. 36, No. 8, p. 1766-1775,
       1997.

    """

    eps = 1e-16
    # Create the otf of the same size as prop: embed the psf at the centre
    # of a zero volume, move that centre to index 0 and transform.
    otf = np.zeros_like(prop)
    centre = tuple(slice(n // 2 - k // 2, n // 2 + k // 2 + 1)
                   for n, k in zip(otf.shape, psf.shape))
    otf[centre] = psf
    # keep the real part (the imaginary part vanishes for a point-symmetric psf)
    otf = np.real(np.fft.fftn(np.fft.ifftshift(otf)))
    # Enforce Positivity
    prop = np.clip(prop, 0, np.inf)
    prop_deconv = prop.copy()
    for _ in range(numit):
        # Blur the estimate
        reBlurred = np.real(np.fft.ifftn(otf * np.fft.fftn(prop_deconv)))
        # guard the division below against zeros / tiny values
        reBlurred[reBlurred < eps] = eps
        # Update the estimate
        correction = np.real(np.fft.ifftn(
            otf * np.fft.fftn((prop / reBlurred) + eps)))
        prop_deconv = prop_deconv * correction ** acc_factor
        # Enforce positivity
        prop_deconv = np.clip(prop_deconv, 0, np.inf)
    total = prop_deconv.sum()
    if total == 0:
        return prop_deconv
    return prop_deconv / total
658
659
if __name__ == "__main__":
    # This module only defines models and helpers; running it as a script
    # is intentionally a no-op.
    pass
662