1import numpy as np
2cimport numpy as cnp
3cimport cython
4
5cimport safe_openmp as openmp
6from safe_openmp cimport have_openmp
7from cython.parallel import parallel, prange, threadid
8from libc.stdlib cimport malloc, free
9
10from dipy.denoise.enhancement_kernel import EnhancementKernel
11from dipy.data import get_sphere
12from dipy.reconst.shm import sh_to_sf, sf_to_sh
13
14from dipy.utils.omp import cpu_count, determine_num_threads
15from dipy.utils.omp cimport set_num_threads, restore_default_num_threads
16
def convolve(odfs_sh, kernel, sh_order, test_mode=False, num_threads=None, normalize=True):
    """ Perform the shift-twist convolution with the ODF data and
    the lookup-table of the kernel.

    The ODFs are converted from spherical harmonics to samples on the
    kernel's sphere, convolved, optionally max-normalized and converted
    back to spherical harmonics.

    Parameters
    ----------
    odfs_sh : array of double
        The ODF data in spherical harmonics format
    kernel : EnhancementKernel
        Kernel object providing ``get_sphere()`` (the sphere on which the
        ODFs are sampled for the convolution) and ``get_lookup_table()``
        (the 5D lookup table)
    sh_order : integer
        Maximal spherical harmonics order
    test_mode : boolean
        Reduced convolution in one direction only for testing
    num_threads : int, optional
        Number of threads to be used for OpenMP parallelization. If None
        (default) the value of OMP_NUM_THREADS environment variable is used
        if it is set, otherwise all available threads are used. If < 0 the
        maximal number of threads minus |num_threads + 1| is used (enter -1 to
        use as many threads as possible). 0 raises an error.
    normalize : boolean
        Apply max-normalization to the output such that its value range matches
        the input ODF data. Normalization is skipped when the convolved data
        is empty or its maximum is zero, since scaling would then produce
        NaN/inf values.

    Returns
    -------
    output : array of double
        The ODF data after convolution enhancement in spherical harmonics format

    References
    ----------
    [Meesters2016_ISMRM] S. Meesters, G. Sanguinetti, E. Garyfallidis,
                         J. Portegies, R. Duits. (2016) Fast implementations of
                         contextual PDE’s for HARDI data processing in DIPY.
                         ISMRM 2016 conference.
    [DuitsAndFranken_IJCV] R. Duits and E. Franken (2011) Left-invariant diffusions
                        on the space of positions and orientations and their
                        application to crossing-preserving smoothing of HARDI
                        images. International Journal of Computer Vision, 92:231-264.
    [Portegies2015] J. Portegies, G. Sanguinetti, S. Meesters, and R. Duits.
                    (2015) New Approximation of a Scale Space Kernel on SE(3) and
                    Applications in Neuroimaging. Fifth International
                    Conference on Scale Space and Variational Methods in
                    Computer Vision
    [Portegies2015b] J. Portegies, R. Fick, G. Sanguinetti, S. Meesters, G.Girard,
                     and R. Duits. (2015) Improving Fiber Alignment in HARDI by
                     Combining Contextual PDE flow with Constrained Spherical
                     Deconvolution. PLoS One.
    """
    # sample the SH-encoded ODFs on the kernel's sphere
    sphere = kernel.get_sphere()
    odfs_dsf = sh_to_sf(odfs_sh, sphere, sh_order=sh_order, basis_type=None)

    # perform_convolution only accepts C-contiguous float64 buffers
    odfs_dsf = np.ascontiguousarray(odfs_dsf, dtype=np.float64)
    lut = np.ascontiguousarray(kernel.get_lookup_table(), dtype=np.float64)

    # perform the convolution; the result is a typed memoryview, so wrap it
    # as an ndarray (no copy) before doing array arithmetic on it
    output = np.asarray(perform_convolution(odfs_dsf, lut, test_mode,
                                            num_threads))

    # rescale so the output maximum matches the input maximum; guard against
    # empty data (np.amax raises) and a zero maximum (division by zero)
    if normalize and output.size > 0:
        output_max = np.amax(output)
        if output_max != 0:
            output = np.multiply(output, np.amax(odfs_dsf) / output_max)

    # convert back to SH
    return sf_to_sh(output, sphere, sh_order=sh_order)
85
def convolve_sf(odfs_sf, kernel, test_mode=False, num_threads=None, normalize=True):
    """ Perform the shift-twist convolution with the ODF data and
    the lookup-table of the kernel.

    Parameters
    ----------
    odfs_sf : array of double
        The ODF data sampled on a sphere. Converted to a C-contiguous
        float64 array if needed.
    kernel : EnhancementKernel
        Kernel object providing ``get_lookup_table()`` (the 5D lookup table)
    test_mode : boolean
        Reduced convolution in one direction only for testing
    num_threads : int, optional
        Number of threads to be used for OpenMP parallelization. If None
        (default) the value of OMP_NUM_THREADS environment variable is used
        if it is set, otherwise all available threads are used. If < 0 the
        maximal number of threads minus |num_threads + 1| is used (enter -1 to
        use as many threads as possible). 0 raises an error.
    normalize : boolean
        Apply max-normalization to the output such that its value range matches
        the input ODF data. Normalization is skipped when the convolved data
        is empty or its maximum is zero, since scaling would then produce
        NaN/inf values.

    Returns
    -------
    output : ndarray of double
        The ODF data after convolution enhancement, sampled on a sphere
    """
    # perform_convolution only accepts C-contiguous float64 buffers
    odfs_sf = np.ascontiguousarray(odfs_sf, dtype=np.float64)
    lut = np.ascontiguousarray(kernel.get_lookup_table(), dtype=np.float64)

    # perform the convolution; the result is a typed memoryview, so wrap it
    # as an ndarray (no copy) so callers always receive an array
    output = np.asarray(perform_convolution(odfs_sf, lut, test_mode,
                                            num_threads))

    # rescale so the output maximum matches the input maximum; guard against
    # empty data (np.amax raises) and a zero maximum (division by zero)
    if normalize and output.size > 0:
        output_max = np.amax(output)
        if output_max != 0:
            output = np.multiply(output, np.amax(odfs_sf) / output_max)

    return output
124
@cython.wraparound(False)
@cython.boundscheck(False)
@cython.nonecheck(False)
@cython.cdivision(True)
cdef double [:, :, :, ::1] perform_convolution (double [:, :, :, ::1] odfs,
                                                double [:, :, :, :, ::1] lut,
                                                cnp.npy_intp test_mode,
                                                num_threads=None):
    """ Perform the shift-twist convolution with the ODF data
    and the lookup-table of the kernel.

    Each output value ``output[cx, cy, cz, corient]`` is the sum, over the
    in-bounds voxels ``(x, y, z)`` within half-width ``hn`` of
    ``(cx, cy, cz)`` and over the input orientations ``orient``, of
    ``odfs[x, y, z, orient] * lut[corient, orient, x - cx + hn,
    y - cy + hn, z - cz + hn]``.

    Parameters
    ----------
    odfs : array of double
        The ODF data sampled on a sphere, indexed (x, y, z, orientation)
    lut : array of double
        The 5D lookup table, indexed (output orientation, input orientation,
        dx, dy, dz)
    test_mode : boolean
        Reduced convolution in one direction only for testing: only the
        first input orientation is used and edge normalization is disabled
    num_threads : int, optional
        Number of threads to be used for OpenMP parallelization. If None
        (default) the value of OMP_NUM_THREADS environment variable is used
        if it is set, otherwise all available threads are used. If < 0 the
        maximal number of threads minus |num_threads + 1| is used (enter -1 to
        use as many threads as possible). 0 raises an error.

    Returns
    -------
    output : array of double
        The ODF data after convolution enhancement, same shape as ``odfs``.
        Orientations beyond the lookup table's first axis keep their input
        values.

    Raises
    ------
    ValueError
        If the lookup table does not fit the ODF data. Bounds checking is
        disabled in this function, so a mismatch would otherwise read or
        write out of bounds.

    Notes
    -----
    With edge normalization, a voxel whose neighbourhood is clipped by the
    volume border is rescaled by (full kernel support) / (in-bounds
    support); interior voxels are left unscaled.
    """

    cdef:
        double [:, :, :, ::1] output = np.array(odfs, copy=True)
        cnp.npy_intp OR1 = lut.shape[0]
        cnp.npy_intp OR2 = lut.shape[1]
        cnp.npy_intp N = lut.shape[2]
        # kernel half-width and the width actually visited per axis
        cnp.npy_intp hn = (N - 1) // 2
        cnp.npy_intp kw = 2 * hn + 1
        cnp.npy_intp nx = odfs.shape[0]
        cnp.npy_intp ny = odfs.shape[1]
        cnp.npy_intp nz = odfs.shape[2]
        cnp.npy_intp n_orient = odfs.shape[3]
        cnp.npy_intp threads_to_use
        cnp.npy_intp corient, orient, cx, cy, cz, x, y, z
        # half-open in-bounds neighbourhood [x0, x1) x [y0, y1) x [z0, z1)
        cnp.npy_intp x0, x1, y0, y1, z0, z1
        double expectedvox, voxcount, totalval
        bint edgeNormalization = True

    if test_mode:
        edgeNormalization = False
        OR2 = 1

    # bounds checking is off below, so reject incompatible inputs up front
    if N < 1 or N < kw or lut.shape[3] < kw or lut.shape[4] < kw:
        raise ValueError(
            "lookup table spatial shape (%d, %d, %d) must be at least %d "
            "along each axis" % (lut.shape[2], lut.shape[3], lut.shape[4],
                                 kw))
    if OR1 > n_orient or OR2 > n_orient:
        raise ValueError(
            "lookup table orientations (%d, %d) exceed the %d orientations "
            "of the ODF data" % (OR1, OR2, n_orient))

    # number of voxels in the full (unclipped) kernel neighbourhood
    expectedvox = <double>(kw * kw * kw)

    threads_to_use = determine_num_threads(num_threads)
    set_num_threads(threads_to_use)
    try:
        with nogil:
            # loop over ODFs cx,cy,cz,corient --> y and v
            for corient in prange(OR1, schedule='guided'):
                for cx in range(nx):
                    # range() ends are exclusive, so clip to nx (not nx - 1)
                    x0 = int_max(cx - hn, 0)
                    x1 = int_min(cx + hn + 1, nx)
                    for cy in range(ny):
                        y0 = int_max(cy - hn, 0)
                        y1 = int_min(cy + hn + 1, ny)
                        for cz in range(nz):
                            z0 = int_max(cz - hn, 0)
                            z1 = int_min(cz + hn + 1, nz)
                            # plain assignment (not +=) keeps totalval
                            # thread-private instead of a prange reduction
                            totalval = 0.0
                            # loop over kernel x,y,z,orient --> x and r
                            for x in range(x0, x1):
                                for y in range(y0, y1):
                                    for z in range(z0, z1):
                                        for orient in range(OR2):
                                            totalval = totalval + \
                                                odfs[x, y, z, orient] * \
                                                lut[corient, orient,
                                                    x - cx + hn,
                                                    y - cy + hn,
                                                    z - cz + hn]
                            if edgeNormalization:
                                # the centre voxel is always in bounds, so
                                # voxcount >= 1
                                voxcount = <double>((x1 - x0) * (y1 - y0) *
                                                    (z1 - z0))
                                output[cx, cy, cz, corient] = \
                                    totalval * expectedvox / voxcount
                            else:
                                output[cx, cy, cz, corient] = totalval
    finally:
        # Reset number of OpenMP cores to default
        if num_threads is not None:
            restore_default_num_threads()

    return output
217
cdef inline cnp.npy_intp int_max(cnp.npy_intp a, cnp.npy_intp b) nogil:
    # Larger of two indices; declared nogil so it can run inside prange.
    if a < b:
        return b
    return a
cdef inline cnp.npy_intp int_min(cnp.npy_intp a, cnp.npy_intp b) nogil:
    # Smaller of two indices; declared nogil so it can run inside prange.
    if b < a:
        return b
    return a
222