1""" Utility functions used by the Cross Correlation (CC) metric """
2
3import numpy as np
4from dipy.align.fused_types cimport floating
5cimport cython
6cimport numpy as cnp
7
8
cdef inline int _int_max(int a, int b) nogil:
    r"""
    Return the larger of the two C integers `a` and `b` (`a` on ties).
    """
    if a < b:
        return b
    return a
14
15
cdef inline int _int_min(int a, int b) nogil:
    r"""
    Return the smaller of the two C integers `a` and `b` (`a` on ties).
    """
    if b < a:
        return b
    return a
21
22
# Indices of the running neighborhood sums used by the CC precomputation
# routines in this module: the last-axis index of `factors`/`temp` inside
# `_update_factors`, and the first index of the `sums`/`lines` buffers.
cdef enum:
    SI = 0  # sum of static values
    SI2 = 1  # sum of squared static values
    SJ = 2  # sum of moving values
    SJ2 = 3  # sum of squared moving values
    SIJ = 4  # sum of static * moving products
    CNT = 5  # number of voxels accumulated (used with the 6-slot buffers)
30
31
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef inline int _wrap(int x, int m) nogil:
    r""" Map a possibly negative position onto the high end of an array.

    A negative position `x` is shifted by the array length `m` so that it
    lands on the last coordinates; non-negative positions are returned
    unchanged. This lets local rectangular windows that run past the low-end
    boundary reuse existing storage instead of requiring padding.

    Parameters
    ----------
    x : int
        the array position to be wrapped
    m : int
        array length
    """
    return x + m if x < 0 else x
51
52
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef inline void _update_factors(double[:, :, :, :] factors,
                                 floating[:, :, :] moving,
                                 floating[:, :, :] static,
                                 cnp.npy_intp ss, cnp.npy_intp rr, cnp.npy_intp cc,
                                 cnp.npy_intp s, cnp.npy_intp r, cnp.npy_intp c, int operation)nogil:
    r"""Apply one voxel's contribution to the CC sums of a window

    The five running sums stored at `factors[ss, rr, cc, :]` (indexed by
    SI, SI2, SJ, SJ2, SIJ) are combined with the values of voxel
    (`s`, `r`, `c`) of `static` and `moving`.

    Parameters
    ----------
    factors : array, shape (S, R, C, 5)
        buffer holding the running sums; modified in place
    moving : array, shape (S, R, C)
        the moving volume (both images must already share a common reference
        domain, i.e. have the same shape)
    static : array, shape (S, R, C)
        the static volume, which also defines the reference registration domain
    ss, rr, cc : int
        coordinates of the window entry of `factors` to be updated
    s, r, c : int
        coordinates of the voxel whose values are used for the update
    operation : int, either -1, 0 or 1
        -1 subtracts the voxel's terms from the stored sums, 1 adds them, and
        0 overwrites the stored sums with them.

    Notes
    -----
    A voxel lying beyond the high end of any axis contributes zero: with
    `operation` 0 the entry is reset to zeros, otherwise nothing changes.
    """
    cdef:
        double i_val, j_val, ii, jj, ij
        int t

    # Out-of-range voxel: acts as an all-zero contribution.
    if s >= moving.shape[0] or r >= moving.shape[1] or c >= moving.shape[2]:
        if operation == 0:
            # SI..SIJ occupy indices 0..4
            for t in range(5):
                factors[ss, rr, cc, t] = 0
        return

    i_val = static[s, r, c]
    j_val = moving[s, r, c]
    ii = i_val * i_val
    jj = j_val * j_val
    ij = i_val * j_val

    if operation == 0:
        factors[ss, rr, cc, SI] = i_val
        factors[ss, rr, cc, SI2] = ii
        factors[ss, rr, cc, SJ] = j_val
        factors[ss, rr, cc, SJ2] = jj
        factors[ss, rr, cc, SIJ] = ij
    elif operation == 1:
        factors[ss, rr, cc, SI] += i_val
        factors[ss, rr, cc, SI2] += ii
        factors[ss, rr, cc, SJ] += j_val
        factors[ss, rr, cc, SJ2] += jj
        factors[ss, rr, cc, SIJ] += ij
    elif operation == -1:
        factors[ss, rr, cc, SI] -= i_val
        factors[ss, rr, cc, SI2] -= ii
        factors[ss, rr, cc, SJ] -= j_val
        factors[ss, rr, cc, SJ2] -= jj
        factors[ss, rr, cc, SIJ] -= ij
126
127
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def precompute_cc_factors_3d(floating[:, :, :] static,
                             floating[:, :, :] moving,
                             cnp.npy_intp radius, num_threads=None):
    r"""Precomputations to quickly compute the gradient of the CC Metric

    Pre-computes the separate terms of the cross correlation metric and image
    norms at each voxel considering a neighborhood of the given radius to
    efficiently compute the gradient of the metric with respect to the
    deformation field [Ocegueda2016]_ [Avants2008]_ [Avants2011]_.

    Parameters
    ----------
    static : array, shape (S, R, C)
        the static volume, which also defines the reference registration domain
    moving : array, shape (S, R, C)
        the moving volume (notice that both images must already be in a common
        reference domain, i.e. the same S, R, C)
    radius : int
        the radius of the neighborhood (cube of (2 * radius + 1)^3 voxels,
        clipped at the volume boundaries)
    num_threads : int or None, optional
        not used by this implementation; accepted for backward compatibility

    Returns
    -------
    factors : array, shape (S, R, C, 5)
        the precomputed cross correlation terms:
        factors[:,:,:,0] : static minus its mean value along the neighborhood
        factors[:,:,:,1] : moving minus its mean value along the neighborhood
        factors[:,:,:,2] : sum, along the neighborhood, of the products of the
                           mean-centered static and moving values
        factors[:,:,:,3] : sum of the squared mean-centered static values
                           along the neighborhood
        factors[:,:,:,4] : sum of the squared mean-centered moving values
                           along the neighborhood

    Notes
    -----
    The fast path keeps running window sums for only two slices and reuses
    row/column slots circularly (see `_wrap`). When radius > 0 that scheme
    requires R >= 2, C >= 2, R >= radius and C >= radius: otherwise a slot
    index becomes negative (an out-of-bounds access, since bounds checking
    is disabled) or a running sum is overwritten before it is read. For such
    shapes the direct computation `precompute_cc_factors_3d_test` is used.

    References
    ----------
    .. [Ocegueda2016]_ Ocegueda, O., Dalmau, O., Garyfallidis, E., Descoteaux,
        M., & Rivera, M. (2016). On the computation of integrals over
        fixed-size rectangles of arbitrary dimension, Pattern Recognition
        Letters. doi:10.1016/j.patrec.2016.05.008
    .. [Avants2008]_ Avants, B. B., Epstein, C. L., Grossman, M., & Gee, J. C.
        (2008). Symmetric Diffeomorphic Image Registration with
        Cross-Correlation: Evaluating Automated Labeling of Elderly and
        Neurodegenerative Brain, Med Image Anal. 12(1), 26-41.
    .. [Avants2011]_ Avants, B. B., Tustison, N., & Song, G. (2011). Advanced
        Normalization Tools (ANTS), 1-35.
    """
    # Degenerate in-plane shapes break the circular-buffer scheme below;
    # fall back to the direct (slower but exact) computation.
    if radius > 0 and (static.shape[1] < 2 or static.shape[2] < 2 or
                       radius > static.shape[1] or radius > static.shape[2]):
        return precompute_cc_factors_3d_test(np.asarray(static),
                                             np.asarray(moving), radius)
    cdef:
        cnp.npy_intp ns = static.shape[0]
        cnp.npy_intp nr = static.shape[1]
        cnp.npy_intp nc = static.shape[2]
        cnp.npy_intp side = 2 * radius + 1
        cnp.npy_intp firstc, lastc, firstr, lastr, firsts, lasts
        cnp.npy_intp s, r, c, it, sides, sider, sidec
        double cnt
        cnp.npy_intp sss, ss, rr, cc, prev_ss, prev_rr, prev_cc
        double Imean, Jmean, IJprods, Isq, Jsq
        # temp[sss] holds the running window sums of the current slice and
        # temp[1 - sss] those of the previous slice
        double[:, :, :, :] temp = np.zeros((2, nr, nc, 5), dtype=np.float64)
        floating[:, :, :, :] factors = np.zeros((ns, nr, nc, 5),
                                                dtype=np.asarray(static).dtype)

    with nogil:
        sss = 1
        for s in range(ns+radius):
            # (s, r, c) is the leading corner of the window; (ss, rr, cc) is
            # the voxel `radius` positions behind it (wrapped when negative)
            ss = _wrap(s - radius, ns)
            sss = 1 - sss
            firsts = _int_max(0, ss - radius)
            lasts = _int_min(ns - 1, ss + radius)
            sides = (lasts - firsts + 1)
            for r in range(nr+radius):
                rr = _wrap(r - radius, nr)
                firstr = _int_max(0, rr - radius)
                lastr = _int_min(nr - 1, rr + radius)
                sider = (lastr - firstr + 1)
                for c in range(nc+radius):
                    cc = _wrap(c - radius, nc)
                    # New corner
                    _update_factors(temp, moving, static,
                                    sss, rr, cc, s, r, c, 0)

                    # Add signed sub-volumes (inclusion-exclusion over the
                    # previously computed neighboring windows)
                    if s > 0:
                        prev_ss = 1 - sss
                        for it in range(5):
                            temp[sss, rr, cc, it] += temp[prev_ss, rr, cc, it]
                        if r > 0:
                            prev_rr = _wrap(rr-1, nr)
                            for it in range(5):
                                temp[sss, rr, cc, it] -= \
                                    temp[prev_ss, prev_rr, cc, it]
                            if c > 0:
                                prev_cc = _wrap(cc-1, nc)
                                for it in range(5):
                                    temp[sss, rr, cc, it] += \
                                        temp[prev_ss, prev_rr, prev_cc, it]
                        if c > 0:
                            prev_cc = _wrap(cc-1, nc)
                            for it in range(5):
                                temp[sss, rr, cc, it] -= \
                                    temp[prev_ss, rr, prev_cc, it]
                    if(r > 0):
                        prev_rr = _wrap(rr-1, nr)
                        for it in range(5):
                            temp[sss, rr, cc, it] += \
                                temp[sss, prev_rr, cc, it]
                        if(c > 0):
                            prev_cc = _wrap(cc-1, nc)
                            for it in range(5):
                                temp[sss, rr, cc, it] -= \
                                    temp[sss, prev_rr, prev_cc, it]
                    if(c > 0):
                        prev_cc = _wrap(cc-1, nc)
                        for it in range(5):
                            temp[sss, rr, cc, it] += temp[sss, rr, prev_cc, it]

                    # Add signed corners (remove voxels that left the window)
                    if s >= side:
                        _update_factors(temp, moving, static,
                                        sss, rr, cc, s-side, r, c, -1)
                        if r >= side:
                            _update_factors(temp, moving, static,
                                            sss, rr, cc, s-side, r-side, c, 1)
                            if c >= side:
                                _update_factors(temp, moving, static, sss, rr,
                                                cc, s-side, r-side, c-side, -1)
                        if c >= side:
                            _update_factors(temp, moving, static,
                                            sss, rr, cc, s-side, r, c-side, 1)
                    if r >= side:
                        _update_factors(temp, moving, static,
                                        sss, rr, cc, s, r-side, c, -1)
                        if c >= side:
                            _update_factors(temp, moving, static,
                                            sss, rr, cc, s, r-side, c-side, 1)

                    if c >= side:
                        _update_factors(temp, moving, static,
                                        sss, rr, cc, s, r, c-side, -1)
                    # Compute final factors once the window centered at
                    # (ss, rr, cc) is complete
                    if s >= radius and r >= radius and c >= radius:
                        firstc = _int_max(0, cc - radius)
                        lastc = _int_min(nc - 1, cc + radius)
                        sidec = (lastc - firstc + 1)
                        cnt = sides*sider*sidec
                        Imean = temp[sss, rr, cc, SI] / cnt
                        Jmean = temp[sss, rr, cc, SJ] / cnt
                        IJprods = (temp[sss, rr, cc, SIJ] -
                                   Jmean * temp[sss, rr, cc, SI] -
                                   Imean * temp[sss, rr, cc, SJ] +
                                   cnt * Jmean * Imean)
                        Isq = (temp[sss, rr, cc, SI2] -
                               Imean * temp[sss, rr, cc, SI] -
                               Imean * temp[sss, rr, cc, SI] +
                               cnt * Imean * Imean)
                        Jsq = (temp[sss, rr, cc, SJ2] -
                               Jmean * temp[sss, rr, cc, SJ] -
                               Jmean * temp[sss, rr, cc, SJ] +
                               cnt * Jmean * Jmean)
                        factors[ss, rr, cc, 0] = static[ss, rr, cc] - Imean
                        factors[ss, rr, cc, 1] = moving[ss, rr, cc] - Jmean
                        factors[ss, rr, cc, 2] = IJprods
                        factors[ss, rr, cc, 3] = Isq
                        factors[ss, rr, cc, 4] = Jsq
    # Return an ndarray, consistent with the other functions in this module
    return np.asarray(factors)
291
292
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def precompute_cc_factors_3d_test(floating[:, :, :] static,
                                  floating[:, :, :] moving, int radius):
    r"""Precomputations to quickly compute the gradient of the CC Metric

    Brute-force reference implementation of precompute_cc_factors_3d,
    intended for testing. Every voxel's neighborhood sums are accumulated
    directly over the cube of side 2 * radius + 1 (clipped at the volume
    boundaries), with no incremental bookkeeping. This is slow but simple
    to verify.
    """
    cdef:
        cnp.npy_intp n_slices = static.shape[0]
        cnp.npy_intp n_rows = static.shape[1]
        cnp.npy_intp n_cols = static.shape[2]
        cnp.npy_intp s, r, c, z, y, x, t
        cnp.npy_intp s_lo, s_hi, r_lo, r_hi, c_lo, c_hi
        double Imean, Jmean, count
        floating[:, :, :, :] factors = np.zeros((n_slices, n_rows, n_cols, 5),
                                                dtype=np.asarray(static).dtype)
        double[:] sums = np.zeros((6,), dtype=np.float64)

    with nogil:
        for s in range(n_slices):
            s_lo = _int_max(0, s - radius)
            s_hi = _int_min(n_slices - 1, s + radius)
            for r in range(n_rows):
                r_lo = _int_max(0, r - radius)
                r_hi = _int_min(n_rows - 1, r + radius)
                for c in range(n_cols):
                    c_lo = _int_max(0, c - radius)
                    c_hi = _int_min(n_cols - 1, c + radius)
                    # reset the accumulators for this voxel
                    for t in range(6):
                        sums[t] = 0
                    for z in range(s_lo, s_hi + 1):
                        for y in range(r_lo, r_hi + 1):
                            for x in range(c_lo, c_hi + 1):
                                sums[SI] += static[z, y, x]
                                sums[SI2] += static[z, y, x]**2
                                sums[SJ] += moving[z, y, x]
                                sums[SJ2] += moving[z, y, x]**2
                                sums[SIJ] += static[z, y, x]*moving[z, y, x]
                                sums[CNT] += 1
                    count = sums[CNT]
                    Imean = sums[SI] / count
                    Jmean = sums[SJ] / count
                    factors[s, r, c, 0] = static[s, r, c] - Imean
                    factors[s, r, c, 1] = moving[s, r, c] - Jmean
                    factors[s, r, c, 2] = (sums[SIJ] - Jmean * sums[SI] -
                                           Imean * sums[SJ] +
                                           count * Jmean * Imean)
                    factors[s, r, c, 3] = (sums[SI2] - Imean * sums[SI] -
                                           Imean * sums[SI] +
                                           count * Imean * Imean)
                    factors[s, r, c, 4] = (sums[SJ2] - Jmean * sums[SJ] -
                                           Jmean * sums[SJ] +
                                           count * Jmean * Jmean)
    return np.asarray(factors)
350
351
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def compute_cc_forward_step_3d(floating[:, :, :, :] grad_static,
                               floating[:, :, :, :] factors,
                               cnp.npy_intp radius):
    r"""Gradient of the CC Metric w.r.t. the forward transformation

    Evaluates, for symmetric registration (SyN) [Avants2008]_, the gradient
    of the Cross Correlation metric with respect to the displacement of the
    moving volume (the 'forward' step), following [Avants2011]_.

    Parameters
    ----------
    grad_static : array, shape (S, R, C, 3)
        the gradient of the static volume
    factors : array, shape (S, R, C, 5)
        the precomputed cross correlation terms obtained via
        precompute_cc_factors_3d
    radius : int
        the radius of the neighborhood used for the CC metric when
        computing the factors. The returned vector field will be
        zero along a boundary of width radius voxels.

    Returns
    -------
    out : array, shape (S, R, C, 3)
        the gradient of the cross correlation metric with respect to the
        displacement associated to the moving volume
    energy : the cross correlation energy (data term) at this iteration

    References
    ----------
    .. [Avants2008]_ Avants, B. B., Epstein, C. L., Grossman, M., & Gee, J. C.
        (2008). Symmetric Diffeomorphic Image Registration with
        Cross-Correlation: Evaluating Automated Labeling of Elderly and
        Neurodegenerative Brain, Med Image Anal. 12(1), 26-41.
    .. [Avants2011]_ Avants, B. B., Tustison, N., & Song, G. (2011). Advanced
        Normalization Tools (ANTS), 1-35.
    """
    cdef:
        cnp.npy_intp ns = grad_static.shape[0]
        cnp.npy_intp nr = grad_static.shape[1]
        cnp.npy_intp nc = grad_static.shape[2]
        cnp.npy_intp s, r, c, d
        double energy = 0
        double Ii, Ji, sfm, sff, smm, denom, local_cc, scale
        floating[:, :, :, :] out =\
            np.zeros((ns, nr, nc, 3), dtype=np.asarray(grad_static).dtype)
    with nogil:
        for s in range(radius, ns - radius):
            for r in range(radius, nr - radius):
                for c in range(radius, nc - radius):
                    sff = factors[s, r, c, 3]
                    smm = factors[s, r, c, 4]
                    # zero variance in either image: leave the gradient at 0
                    if sff == 0.0 or smm == 0.0:
                        continue
                    Ii = factors[s, r, c, 0]
                    Ji = factors[s, r, c, 1]
                    sfm = factors[s, r, c, 2]
                    denom = sff * smm
                    local_cc = sfm * sfm / denom if denom > 1e-5 else 0
                    # avoid bad values...
                    if local_cc < 1:
                        energy -= local_cc
                    scale = 2.0 * sfm / denom * (Ji - sfm / sff * Ii)
                    for d in range(3):
                        out[s, r, c, d] -= scale * grad_static[s, r, c, d]
    return np.asarray(out), energy
422
423
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def compute_cc_backward_step_3d(floating[:, :, :, :] grad_moving,
                                floating[:, :, :, :] factors,
                                cnp.npy_intp radius):
    r"""Gradient of the CC Metric w.r.t. the backward transformation

    Evaluates, for symmetric registration (SyN) [Avants2008]_, the gradient
    of the Cross Correlation metric with respect to the displacement of the
    static volume (the 'backward' step), following [Avants2011]_.

    Parameters
    ----------
    grad_moving : array, shape (S, R, C, 3)
        the gradient of the moving volume
    factors : array, shape (S, R, C, 5)
        the precomputed cross correlation terms obtained via
        precompute_cc_factors_3d
    radius : int
        the radius of the neighborhood used for the CC metric when
        computing the factors. The returned vector field will be
        zero along a boundary of width radius voxels.

    Returns
    -------
    out : array, shape (S, R, C, 3)
        the gradient of the cross correlation metric with respect to the
        displacement associated to the static volume
    energy : the cross correlation energy (data term) at this iteration

    References
    ----------
    .. [Avants2008]_ Avants, B. B., Epstein, C. L., Grossman, M., & Gee, J. C.
        (2008). Symmetric Diffeomorphic Image Registration with
        Cross-Correlation: Evaluating Automated Labeling of Elderly and
        Neurodegenerative Brain, Med Image Anal. 12(1), 26-41.
    .. [Avants2011]_ Avants, B. B., Tustison, N., & Song, G. (2011). Advanced
        Normalization Tools (ANTS), 1-35.
    """
    cdef:
        cnp.npy_intp ns = grad_moving.shape[0]
        cnp.npy_intp nr = grad_moving.shape[1]
        cnp.npy_intp nc = grad_moving.shape[2]
        cnp.npy_intp s, r, c, d
        double energy = 0
        double Ii, Ji, sfm, sff, smm, denom, local_cc, scale
        floating[:, :, :, :] out = np.zeros((ns, nr, nc, 3),
                                            dtype=np.asarray(grad_moving).dtype)

    with nogil:
        for s in range(radius, ns - radius):
            for r in range(radius, nr - radius):
                for c in range(radius, nc - radius):
                    sff = factors[s, r, c, 3]
                    smm = factors[s, r, c, 4]
                    # zero variance in either image: leave the gradient at 0
                    if sff == 0.0 or smm == 0.0:
                        continue
                    Ii = factors[s, r, c, 0]
                    Ji = factors[s, r, c, 1]
                    sfm = factors[s, r, c, 2]
                    denom = sff * smm
                    local_cc = sfm * sfm / denom if denom > 1e-5 else 0
                    # avoid bad values...
                    if local_cc < 1:
                        energy -= local_cc
                    scale = 2.0 * sfm / denom * (Ii - sfm / smm * Ji)
                    for d in range(3):
                        out[s, r, c, d] -= scale * grad_moving[s, r, c, d]
    return np.asarray(out), energy
496
497
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def precompute_cc_factors_2d(floating[:, :] static, floating[:, :] moving,
                             cnp.npy_intp radius):
    r"""Precomputations to quickly compute the gradient of the CC Metric

    Pre-computes the separate terms of the cross correlation metric
    [Avants2008]_ and image norms at each pixel considering a neighborhood of
    the given radius to efficiently [Avants2011]_ compute the gradient of the
    metric with respect to the deformation field.

    Parameters
    ----------
    static : array, shape (R, C)
        the static image, which also defines the reference registration domain
    moving : array, shape (R, C)
        the moving image (notice that both images must already be in a common
        reference domain, i.e. the same R, C)
    radius : int
        the radius of the neighborhood (square of (2*radius + 1)^2 pixels,
        clipped at the image boundaries)

    Returns
    -------
    factors : array, shape (R, C, 5)
        the precomputed cross correlation terms:
        factors[:,:,0] : static minus its mean value along the neighborhood
        factors[:,:,1] : moving minus its mean value along the neighborhood
        factors[:,:,2] : sum, along the neighborhood, of the products of the
                         mean-centered static and moving values
        factors[:,:,3] : sum of the squared mean-centered static values
                         along the neighborhood
        factors[:,:,4] : sum of the squared mean-centered moving values
                         along the neighborhood

    References
    ----------
    .. [Avants2008]_ Avants, B. B., Epstein, C. L., Grossman, M., & Gee, J. C.
        (2008). Symmetric Diffeomorphic Image Registration with
        Cross-Correlation: Evaluating Automated Labeling of Elderly and
        Neurodegenerative Brain, Med Image Anal. 12(1), 26-41.
    .. [Avants2011]_ Avants, B. B., Tustison, N., & Song, G. (2011). Advanced
        Normalization Tools (ANTS), 1-35.
    """
    ftype = np.asarray(static).dtype
    cdef:
        cnp.npy_intp side = 2 * radius + 1
        cnp.npy_intp nr = static.shape[0]
        cnp.npy_intp nc = static.shape[1]
        cnp.npy_intp r, c, i, j, t, q, qq, firstc, lastc, first_tail
        double Imean, Jmean
        floating[:, :, :] factors = np.zeros((nr, nc, 5), dtype=ftype)
        # lines[:, q] holds the per-row sums of one of the last `side` rows,
        # used as a circular buffer indexed by row % side
        double[:, :] lines = np.zeros((6, side), dtype=np.float64)
        double[:] sums = np.zeros((6,), dtype=np.float64)

    # Rows whose window reaches past the last image row are finished after
    # the main sweep, starting at nr - radius. Clamp at 0: when nr < radius
    # that start is negative and, with wraparound disabled, factors[r, ...]
    # would be written out of bounds. Skipping those rows leaves `sums`
    # unchanged, since the buffer slots they would subtract were never filled.
    first_tail = nr - radius
    if first_tail < 0:
        first_tail = 0

    with nogil:

        for c in range(nc):
            firstc = _int_max(0, c - radius)
            lastc = _int_min(nc - 1, c + radius)
            # compute factors for column [:, c]
            for t in range(6):
                for q in range(side):
                    lines[t, q] = 0
            # Compute all rows and set the sums on the fly
            # compute row [i, j = {c-radius, c + radius}]
            for i in range(nr):
                q = i % side
                for t in range(6):
                    lines[t, q] = 0
                for j in range(firstc, lastc + 1):
                    lines[SI, q] += static[i, j]
                    lines[SI2, q] += static[i, j] * static[i, j]
                    lines[SJ, q] += moving[i, j]
                    lines[SJ2, q] += moving[i, j] * moving[i, j]
                    lines[SIJ, q] += static[i, j] * moving[i, j]
                    lines[CNT, q] += 1

                for t in range(6):
                    sums[t] = 0
                    for qq in range(side):
                        sums[t] += lines[t, qq]
                if(i >= radius):
                    # r is the pixel that is affected by the cube with slices
                    # [r - radius.. r + radius, :]
                    r = i - radius
                    Imean = sums[SI] / sums[CNT]
                    Jmean = sums[SJ] / sums[CNT]
                    factors[r, c, 0] = static[r, c] - Imean
                    factors[r, c, 1] = moving[r, c] - Jmean
                    factors[r, c, 2] = (sums[SIJ] - Jmean * sums[SI] -
                                        Imean * sums[SJ] +
                                        sums[CNT] * Jmean * Imean)
                    factors[r, c, 3] = (sums[SI2] - Imean * sums[SI] -
                                        Imean * sums[SI] +
                                        sums[CNT] * Imean * Imean)
                    factors[r, c, 4] = (sums[SJ2] - Jmean * sums[SJ] -
                                        Jmean * sums[SJ] +
                                        sums[CNT] * Jmean * Jmean)
            # Finally set the values at the end of the line
            for r in range(first_tail, nr):
                # this would be the last slice to be processed for pixel
                # [r, c], if it existed; drop the row that left the window
                i = r + radius
                q = i % side
                for t in range(6):
                    sums[t] -= lines[t, q]
                Imean = sums[SI] / sums[CNT]
                Jmean = sums[SJ] / sums[CNT]
                factors[r, c, 0] = static[r, c] - Imean
                factors[r, c, 1] = moving[r, c] - Jmean
                factors[r, c, 2] = (sums[SIJ] - Jmean * sums[SI] -
                                    Imean * sums[SJ] +
                                    sums[CNT] * Jmean * Imean)
                factors[r, c, 3] = (sums[SI2] - Imean * sums[SI] -
                                    Imean * sums[SI] +
                                    sums[CNT] * Imean * Imean)
                factors[r, c, 4] = (sums[SJ2] - Jmean * sums[SJ] -
                                    Jmean * sums[SJ] +
                                    sums[CNT] * Jmean * Jmean)
    return np.asarray(factors)
616
617
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def precompute_cc_factors_2d_test(floating[:, :] static, floating[:, :] moving,
                                  cnp.npy_intp radius):
    r"""Precomputations to quickly compute the gradient of the CC Metric

    Brute-force reference implementation of precompute_cc_factors_2d,
    intended for testing. Every pixel's neighborhood sums are accumulated
    directly over the square of side 2 * radius + 1 (clipped at the image
    boundaries), with no incremental bookkeeping.
    """
    cdef:
        cnp.npy_intp n_rows = static.shape[0]
        cnp.npy_intp n_cols = static.shape[1]
        cnp.npy_intp r, c, y, x, t, r_lo, r_hi, c_lo, c_hi
        double Imean, Jmean, count
        floating[:, :, :] factors = np.zeros((n_rows, n_cols, 5),
                                             dtype=np.asarray(static).dtype)
        double[:] sums = np.zeros((6,), dtype=np.float64)

    with nogil:
        for r in range(n_rows):
            r_lo = _int_max(0, r - radius)
            r_hi = _int_min(n_rows - 1, r + radius)
            for c in range(n_cols):
                c_lo = _int_max(0, c - radius)
                c_hi = _int_min(n_cols - 1, c + radius)
                # reset the accumulators for this pixel
                for t in range(6):
                    sums[t] = 0
                for y in range(r_lo, r_hi + 1):
                    for x in range(c_lo, c_hi + 1):
                        sums[SI] += static[y, x]
                        sums[SI2] += static[y, x]**2
                        sums[SJ] += moving[y, x]
                        sums[SJ2] += moving[y, x]**2
                        sums[SIJ] += static[y, x]*moving[y, x]
                        sums[CNT] += 1
                count = sums[CNT]
                Imean = sums[SI] / count
                Jmean = sums[SJ] / count
                factors[r, c, 0] = static[r, c] - Imean
                factors[r, c, 1] = moving[r, c] - Jmean
                factors[r, c, 2] = (sums[SIJ] - Jmean * sums[SI] -
                                    Imean * sums[SJ] +
                                    count * Jmean * Imean)
                factors[r, c, 3] = (sums[SI2] - Imean * sums[SI] -
                                    Imean * sums[SI] +
                                    count * Imean * Imean)
                factors[r, c, 4] = (sums[SJ2] - Jmean * sums[SJ] -
                                    Jmean * sums[SJ] +
                                    count * Jmean * Jmean)
    return np.asarray(factors)
669
670
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def compute_cc_forward_step_2d(floating[:, :, :] grad_static,
                               floating[:, :, :] factors,
                               cnp.npy_intp radius):
    r"""Gradient of the CC Metric w.r.t. the forward transformation

    Computes the gradient of the Cross Correlation metric for symmetric
    registration (SyN) [Avants2008]_ w.r.t. the displacement associated to
    the static image ('forward' step) as in [Avants2011]_

    Parameters
    ----------
    grad_static : array, shape (R, C, 2)
        the gradient of the static image
    factors : array, shape (R, C, 5)
        the precomputed cross correlation terms obtained via
        precompute_cc_factors_2d. Channels 0 and 1 are the centered static
        and moving intensities; channels 2, 3 and 4 are the local
        static-moving, static-static and moving-moving sums (sfm, sff, smm)
    radius : int
        the radius of the local correlation window. Voxels closer than
        `radius` to the image boundary are not processed and their gradient
        is left at zero

    Returns
    -------
    out : array, shape (R, C, 2)
        the gradient of the cross correlation metric with respect to the
        displacement associated to the static image
    energy : float
        the cross correlation energy (data term) at this iteration: minus the
        sum of the local squared correlations sfm^2 / (sff * smm), where only
        values below 1 are accumulated

    Notes
    -----
    Only the gradient of the static image is used here. Some authors suggest
    that symmetrizing the gradient by including both the moving and the
    static gradients may improve the registration quality; this is left for
    future investigation.

    Voxels whose local sff or smm term (or their product, e.g. after
    floating-point underflow) is zero are skipped, so their gradient stays
    at zero instead of becoming inf/nan.

    References
    ----------
    .. [Avants2008]_ Avants, B. B., Epstein, C. L., Grossman, M., & Gee, J. C.
        (2008). Symmetric Diffeomorphic Image Registration with
        Cross-Correlation: Evaluating Automated Labeling of Elderly and
        Neurodegenerative Brain, Med Image Anal. 12(1), 26-41.
    .. [Avants2011]_ Avants, B. B., Tustison, N., & Song, G. (2011). Advanced
        Normalization Tools (ANTS), 1-35.
    """
    ftype = np.asarray(grad_static).dtype
    cdef:
        cnp.npy_intp nr = grad_static.shape[0]
        cnp.npy_intp nc = grad_static.shape[1]
        double energy = 0
        cnp.npy_intp r, c
        double Ii, Ji, sfm, sff, smm, localCorrelation, temp, denom
        floating[:, :, :] out = np.zeros((nr, nc, 2), dtype=ftype)
    with nogil:
        for r in range(radius, nr - radius):
            for c in range(radius, nc - radius):
                Ii = factors[r, c, 0]
                Ji = factors[r, c, 1]
                sfm = factors[r, c, 2]
                sff = factors[r, c, 3]
                smm = factors[r, c, 4]
                denom = sff * smm
                # Skip degenerate windows: a zero term, or a product that
                # underflows to zero, would make the divisions below yield
                # inf/nan (cdivision is on, so nothing is raised).
                if sff == 0.0 or smm == 0.0 or denom == 0.0:
                    continue
                localCorrelation = 0
                # Only windows with a non-negligible denominator contribute
                # to the energy.
                if denom > 1e-5:
                    localCorrelation = sfm * sfm / denom
                if localCorrelation < 1:  # avoid bad values...
                    energy -= localCorrelation
                # Derivative of the local CC w.r.t. the static intensity,
                # chained with the static gradient below.
                temp = 2.0 * sfm / denom * (Ji - sfm / sff * Ii)
                out[r, c, 0] -= temp * grad_static[r, c, 0]
                out[r, c, 1] -= temp * grad_static[r, c, 1]
    return np.asarray(out), energy
742
743
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def compute_cc_backward_step_2d(floating[:, :, :] grad_moving,
                                floating[:, :, :] factors,
                                cnp.npy_intp radius):
    r"""Gradient of the CC Metric w.r.t. the backward transformation

    Computes the gradient of the Cross Correlation metric for symmetric
    registration (SyN) [Avants2008]_ w.r.t. the displacement associated to
    the moving image ('backward' step) as in [Avants2011]_

    Parameters
    ----------
    grad_moving : array, shape (R, C, 2)
        the gradient of the moving image
    factors : array, shape (R, C, 5)
        the precomputed cross correlation terms obtained via
        precompute_cc_factors_2d. Channels 0 and 1 are the centered static
        and moving intensities; channels 2, 3 and 4 are the local
        static-moving, static-static and moving-moving sums (sfm, sff, smm)
    radius : int
        the radius of the local correlation window. Voxels closer than
        `radius` to the image boundary are not processed and their gradient
        is left at zero

    Returns
    -------
    out : array, shape (R, C, 2)
        the gradient of the cross correlation metric with respect to the
        displacement associated to the moving image
    energy : float
        the cross correlation energy (data term) at this iteration: minus the
        sum of the local squared correlations sfm^2 / (sff * smm), where only
        values below 1 are accumulated

    Notes
    -----
    Only the gradient of the moving image is used here. Voxels whose local
    sff or smm term (or their product, e.g. after floating-point underflow)
    is zero are skipped, so their gradient stays at zero instead of becoming
    inf/nan.

    References
    ----------
    .. [Avants2008]_ Avants, B. B., Epstein, C. L., Grossman, M., & Gee, J. C.
        (2008). Symmetric Diffeomorphic Image Registration with
        Cross-Correlation: Evaluating Automated Labeling of Elderly and
        Neurodegenerative Brain, Med Image Anal. 12(1), 26-41.
    .. [Avants2011]_ Avants, B. B., Tustison, N., & Song, G. (2011). Advanced
        Normalization Tools (ANTS), 1-35.
    """
    ftype = np.asarray(grad_moving).dtype
    cdef:
        cnp.npy_intp nr = grad_moving.shape[0]
        cnp.npy_intp nc = grad_moving.shape[1]
        cnp.npy_intp r, c
        double energy = 0
        double Ii, Ji, sfm, sff, smm, localCorrelation, temp, denom
        floating[:, :, :] out = np.zeros((nr, nc, 2), dtype=ftype)

    with nogil:
        for r in range(radius, nr - radius):
            for c in range(radius, nc - radius):
                Ii = factors[r, c, 0]
                Ji = factors[r, c, 1]
                sfm = factors[r, c, 2]
                sff = factors[r, c, 3]
                smm = factors[r, c, 4]
                denom = sff * smm
                # Skip degenerate windows: a zero term, or a product that
                # underflows to zero, would make the divisions below yield
                # inf/nan (cdivision is on, so nothing is raised).
                if sff == 0.0 or smm == 0.0 or denom == 0.0:
                    continue
                localCorrelation = 0
                # Only windows with a non-negligible denominator contribute
                # to the energy.
                if denom > 1e-5:
                    localCorrelation = sfm * sfm / denom
                if localCorrelation < 1:  # avoid bad values...
                    energy -= localCorrelation
                # Derivative of the local CC w.r.t. the moving intensity,
                # chained with the moving gradient below.
                temp = 2.0 * sfm / denom * (Ii - sfm / smm * Ji)
                out[r, c, 0] -= temp * grad_moving[r, c, 0]
                out[r, c, 1] -= temp * grad_moving[r, c, 1]
    return np.asarray(out), energy
809