1"""  Metrics for Symmetric Diffeomorphic Registration """
2
3import abc
4import numpy as np
5import scipy as sp
6from numpy import gradient
7from scipy import ndimage
8from dipy.align import vector_fields as vfu
9from dipy.align import sumsqdiff as ssd
10from dipy.align import crosscorr as cc
11from dipy.align import expectmax as em
12from dipy.align import floating
13
14
class SimilarityMetric(metaclass=abc.ABCMeta):
    def __init__(self, dim):
        r""" Similarity Metric abstract class

        Common interface of the metrics that drive a gradient-based
        registration. A metric holds the current static and moving images,
        reports the numerical value of their similarity (or distance) and
        computes the update fields for the forward and inverse displacement
        fields. The metric does not depend on any transformation (affine or
        non-linear): the images it is given are assumed to be already warped.

        Parameters
        ----------
        dim : int (either 2 or 3)
            the dimension of the image domain
        """
        self.dim = dim

        # Position inside the pyramid, filled in through
        # set_levels_above / set_levels_below
        self.levels_above = self.levels_below = None

        # Static image and the grid information that comes with it
        self.static_image = self.static_affine = None
        self.static_spacing = self.static_direction = None

        # Moving image and the grid information that comes with it
        self.moving_image = self.moving_affine = None
        self.moving_spacing = self.moving_direction = None

        self.mask0 = False

    def set_levels_below(self, levels):
        r"""Records how many pyramid levels are below the current one

        The value is only stored here; a concrete metric may use it to adapt
        its behavior (e.g. the number of inner iterations).

        Parameters
        ----------
        levels : int
            the number of levels below the current Gaussian Pyramid level
        """
        self.levels_below = levels

    def set_levels_above(self, levels):
        r"""Records how many pyramid levels are above the current one

        The value is only stored here; a concrete metric may use it to adapt
        its behavior (e.g. the number of inner iterations).

        Parameters
        ----------
        levels : int
            the number of levels above the current Gaussian Pyramid level
        """
        self.levels_above = levels

    def set_static_image(self, static_image, static_affine, static_spacing,
                         static_direction):
        r"""Sets the static image being compared against the moving one.

        The base implementation only keeps references to its arguments.
        Specializations of the metric may need to perform other operations.

        Parameters
        ----------
        static_image : array, shape (R, C) or (S, R, C)
            the static image
        static_affine : object
            affine information of the static image, stored as given
        static_spacing : object
            spacing information of the static image, stored as given
        static_direction : object
            direction information of the static image, stored as given
        """
        (self.static_image, self.static_affine,
         self.static_spacing, self.static_direction) = (
            static_image, static_affine, static_spacing, static_direction)

    def use_static_image_dynamics(self, original_static_image, transformation):
        r"""Hook called by the optimizer just after setting the static image.

        Gives the metric the chance to derive useful information from the way
        the current static image was generated (as the transformation of an
        original static image). The base implementation does nothing.

        Parameters
        ----------
        original_static_image : array, shape (R, C) or (S, R, C)
            original image from which the current static image was generated
        transformation : DiffeomorphicMap object or None
            the transformation that was applied to the original image to
            generate the current static image
        """

    def set_moving_image(self, moving_image, moving_affine, moving_spacing,
                         moving_direction):
        r"""Sets the moving image being compared against the static one.

        The base implementation only keeps references to its arguments.
        Specializations of the metric may need to perform other operations.

        Parameters
        ----------
        moving_image : array, shape (R, C) or (S, R, C)
            the moving image
        moving_affine : object
            affine information of the moving image, stored as given
        moving_spacing : object
            spacing information of the moving image, stored as given
        moving_direction : object
            direction information of the moving image, stored as given
        """
        (self.moving_image, self.moving_affine,
         self.moving_spacing, self.moving_direction) = (
            moving_image, moving_affine, moving_spacing, moving_direction)

    def use_moving_image_dynamics(self, original_moving_image, transformation):
        r"""Hook called by the optimizer just after setting the moving image.

        Gives the metric the chance to derive useful information from the way
        the current moving image was generated (as the transformation of an
        original moving image). The base implementation does nothing.

        Parameters
        ----------
        original_moving_image : array, shape (R, C) or (S, R, C)
            original image from which the current moving image was generated
        transformation : DiffeomorphicMap object or None
            the transformation that was applied to the original image to
            generate the current moving image
        """

    @abc.abstractmethod
    def initialize_iteration(self):
        r"""Prepares the metric to compute one displacement field iteration.

        Called before any compute_forward or compute_backward call so that
        the metric can pre-compute whatever speeds up the update
        computations. This initialization was needed in ANTS because the
        updates are called once per voxel. In Python this is impractical,
        though.
        """

    @abc.abstractmethod
    def free_iteration(self):
        r"""Releases the resources no longer needed by the metric

        Called by the RegistrationOptimizer once the required iterations
        (forward and / or backward) have been computed, so that the metric
        can safely delete any data it computed as part of the initialization.
        """

    @abc.abstractmethod
    def compute_forward(self):
        r"""Computes one step bringing the moving image towards the static.

        Must return the forward update field used to register the moving
        image towards the static image in a gradient-based optimization
        algorithm.
        """

    @abc.abstractmethod
    def compute_backward(self):
        r"""Computes one step bringing the static image towards the moving.

        Must return the backward update field used to register the static
        image towards the moving image in a gradient-based optimization
        algorithm.
        """

    @abc.abstractmethod
    def get_energy(self):
        r"""Numerical value assigned by this metric to the current image pair

        Must return the numeric value of the similarity between the given
        static and moving images.
        """
196
class CCMetric(SimilarityMetric):

    def __init__(self, dim, sigma_diff=2.0, radius=4):
        r"""Normalized Cross-Correlation Similarity metric.

        Parameters
        ----------
        dim : int (either 2 or 3)
            the dimension of the image domain
        sigma_diff : float
            the standard deviation of the Gaussian smoothing kernel to
            be applied to the update field at each iteration
        radius : int
            the radius of the squared (cubic) neighborhood at each voxel to be
            considered to compute the cross correlation

        Raises
        ------
        ValueError
            if `dim` is neither 2 nor 3
        """
        super().__init__(dim)
        self.sigma_diff = sigma_diff
        self.radius = radius
        self._connect_functions()

    def _connect_functions(self):
        r"""Assign the methods to be called according to the image dimension

        Binds the 2D or 3D versions of the functions that pre-compute the
        cross-correlation factors, compute the forward / backward steps and
        reorient vector fields, according to `self.dim`.

        Raises
        ------
        ValueError
            if `self.dim` is neither 2 nor 3
        """
        if self.dim == 2:
            self.precompute_factors = cc.precompute_cc_factors_2d
            self.compute_forward_step = cc.compute_cc_forward_step_2d
            self.compute_backward_step = cc.compute_cc_backward_step_2d
            self.reorient_vector_field = vfu.reorient_vector_field_2d
        elif self.dim == 3:
            self.precompute_factors = cc.precompute_cc_factors_3d
            self.compute_forward_step = cc.compute_cc_forward_step_3d
            self.compute_backward_step = cc.compute_cc_backward_step_3d
            self.reorient_vector_field = vfu.reorient_vector_field_3d
        else:
            raise ValueError('CC Metric not defined for dim. %d' % (self.dim))

    def _physical_gradient(self, image, spacing, direction):
        r"""Gradient of `image` expressed in physical space

        The voxel-space gradient is divided by `spacing` and reoriented
        (in place) with `direction`; either conversion is skipped when the
        corresponding argument is None.

        Parameters
        ----------
        image : array, shape (R, C) or (S, R, C)
            the image to differentiate
        spacing : array or None
            the spacing the gradient components are divided by
        direction : array or None
            the direction information handed to `reorient_vector_field`

        Returns
        -------
        grad_field : array, shape (R, C, 2) or (S, R, C, 3)
            the gradient field, one component per image axis
        """
        grad_field = np.empty(shape=image.shape + (self.dim,),
                              dtype=floating)
        for axis, partial in enumerate(gradient(image)):
            grad_field[..., axis] = partial

        if spacing is not None:
            grad_field /= spacing
        if direction is not None:
            self.reorient_vector_field(grad_field, direction)
        return grad_field

    def _smooth_displacement(self, displacement):
        r"""Gaussian-smooth each component of a displacement field

        Parameters
        ----------
        displacement : array-like, shape (R, C, 2) or (S, R, C, 3)
            the raw update field

        Returns
        -------
        displacement : array, shape (R, C, 2) or (S, R, C, 3)
            a new array whose components were smoothed with a Gaussian
            kernel of standard deviation `self.sigma_diff`
        """
        displacement = np.array(displacement)
        for i in range(self.dim):
            # ndimage.gaussian_filter is the public entry point; the
            # ndimage.filters namespace is deprecated in SciPy
            displacement[..., i] = ndimage.gaussian_filter(
                displacement[..., i], self.sigma_diff)
        return displacement

    def initialize_iteration(self):
        r"""Prepares the metric to compute one displacement field iteration.

        Pre-computes the cross-correlation factors for efficient computation
        of the gradient of the Cross Correlation w.r.t. the displacement field.
        It also pre-computes the image gradients in the physical space by
        re-orienting the gradients in the voxel space using the corresponding
        spacing and direction information.

        Raises
        ------
        ValueError
            if any dimension of the static or the moving image is smaller
            than 2 * radius + 1
        """
        min_size = 2 * self.radius + 1
        msg = ("Each image dimension should be at least 2 * radius + 1. "
               "Decrease CCMetric radius or increase your image size")

        # The static image is validated first, then the moving one
        for label, image in (("Static", self.static_image),
                             ("Moving", self.moving_image)):
            if any(size < min_size for size in image.shape):
                raise ValueError("%s image size is too small. " % label + msg)

        self.factors = np.array(self.precompute_factors(self.static_image,
                                                        self.moving_image,
                                                        self.radius))

        # Convert both gradient fields from voxel to physical space
        self.gradient_moving = self._physical_gradient(
            self.moving_image, self.moving_spacing, self.moving_direction)
        self.gradient_static = self._physical_gradient(
            self.static_image, self.static_spacing, self.static_direction)

    def free_iteration(self):
        r"""Frees the resources allocated during initialization

        Deletes the factors and gradient fields created by
        `initialize_iteration`.
        """
        del self.factors
        del self.gradient_moving
        del self.gradient_static

    def compute_forward(self):
        r"""Computes one step bringing the moving image towards the static.

        Computes the update displacement field to be used for registration of
        the moving image towards the static image. The energy returned by the
        forward step is stored and later reported by `get_energy`.

        Returns
        -------
        displacement : array, shape (R, C, 2) or (S, R, C, 3)
            the smoothed forward update field
        """
        displacement, self.energy = self.compute_forward_step(
            self.gradient_static, self.factors, self.radius)
        return self._smooth_displacement(displacement)

    def compute_backward(self):
        r"""Computes one step bringing the static image towards the moving.

        Computes the update displacement field to be used for registration of
        the static image towards the moving image. The energy returned by the
        backward step is discarded: only `compute_forward` updates the value
        reported by `get_energy`.

        Returns
        -------
        displacement : array, shape (R, C, 2) or (S, R, C, 3)
            the smoothed backward update field
        """
        displacement, _ = self.compute_backward_step(self.gradient_moving,
                                                     self.factors,
                                                     self.radius)
        return self._smooth_displacement(displacement)

    def get_energy(self):
        r"""Numerical value assigned by this metric to the current image pair

        Returns the Cross Correlation (data term) energy stored by the latest
        call to `compute_forward`. The attribute does not exist before that
        first call.
        """
        return self.energy
332
class EMMetric(SimilarityMetric):
    def __init__(self,
                 dim,
                 smooth=1.0,
                 inner_iter=5,
                 q_levels=256,
                 double_gradient=True,
                 step_type='gauss_newton'):
        r"""Expectation-Maximization Metric

        Similarity metric based on the Expectation-Maximization algorithm to
        handle multi-modal images. The transfer function is modeled as a set of
        hidden random variables that are estimated at each iteration of the
        algorithm.

        Parameters
        ----------
        dim : int (either 2 or 3)
            the dimension of the image domain
        smooth : float
            smoothness parameter, the larger the value the smoother the
            deformation field
        inner_iter : int
            number of iterations to be performed at each level of the multi-
            resolution Gauss-Seidel optimization algorithm (this is not the
            number of steps per Gaussian Pyramid level, that parameter must
            be set for the optimizer, not the metric)
        q_levels : int
            number of quantization levels (equal to the number of hidden
            variables in the EM algorithm)
        double_gradient : boolean
            if True, the gradient of the expected static image under the moving
            modality will be added to the gradient of the moving image,
            similarly, the gradient of the expected moving image under the
            static modality will be added to the gradient of the static image.
        step_type : string ('gauss_newton', 'demons')
            the kind of update step computed by `compute_forward` and
            `compute_backward`

        Raises
        ------
        ValueError
            if `dim` is neither 2 nor 3, or `step_type` is not one of the
            supported values
        """
        super().__init__(dim)
        self.smooth = smooth
        self.inner_iter = inner_iter
        self.q_levels = q_levels
        self.use_double_gradient = double_gradient
        self.step_type = step_type
        self.static_image_mask = None
        self.moving_image_mask = None
        self.staticq_means_field = None
        self.movingq_means_field = None
        self.movingq_levels = None
        self.staticq_levels = None
        self._connect_functions()

    def _connect_functions(self):
        r"""Assign the methods to be called according to the image dimension

        Binds the 2D or 3D versions of the functions used for image
        quantization, statistics computation and vector field reorientation,
        and selects the step computation matching `self.step_type`.

        Raises
        ------
        ValueError
            if `self.dim` is neither 2 nor 3, or `self.step_type` is neither
            'demons' nor 'gauss_newton'
        """
        if self.dim == 2:
            self.quantize = em.quantize_positive_2d
            self.compute_stats = em.compute_masked_class_stats_2d
            self.reorient_vector_field = vfu.reorient_vector_field_2d
        elif self.dim == 3:
            self.quantize = em.quantize_positive_3d
            self.compute_stats = em.compute_masked_class_stats_3d
            self.reorient_vector_field = vfu.reorient_vector_field_3d
        else:
            raise ValueError('EM Metric not defined for dim. %d' % (self.dim))

        if self.step_type == 'demons':
            self.compute_step = self.compute_demons_step
        elif self.step_type == 'gauss_newton':
            self.compute_step = self.compute_gauss_newton_step
        else:
            raise ValueError('Opt. step %s not defined' % (self.step_type))

    def _class_statistics(self, source, target, sampling_mask):
        r"""Quantize `source` and summarize `target` per quantization level

        Parameters
        ----------
        source : array, shape (R, C) or (S, R, C)
            the image to be quantized into `self.q_levels` levels
        target : array, shape (R, C) or (S, R, C)
            the image whose per-level means and variances are computed
        sampling_mask : array, shape (R, C) or (S, R, C)
            the mask handed to the statistics function

        Returns
        -------
        levels : array
            the quantization levels returned by the quantizer
        means : array
            per-level means of `target`; the entry of level 0 is forced to 0
        variances : array
            per-level variances of `target`
        sigma_sq_field : array
            `variances` looked up at the quantized label of each voxel
        means_field : array
            `means` looked up at the quantized label of each voxel
        """
        # The third value returned by the quantizer is not needed here
        labels, levels, _ = self.quantize(source, self.q_levels)
        labels = np.array(labels, dtype=np.int32)
        means, variances = self.compute_stats(sampling_mask, target,
                                              self.q_levels, labels)
        means[0] = 0
        means = np.array(means)
        variances = np.array(variances)
        return (np.array(levels), means, variances,
                variances[labels], means[labels])

    def _physical_gradient(self, image, spacing, direction):
        r"""Gradient of `image` expressed in physical space

        The voxel-space gradient is divided by `spacing` and reoriented
        (in place) with `direction`; either conversion is skipped when the
        corresponding argument is None.

        Parameters
        ----------
        image : array, shape (R, C) or (S, R, C)
            the image to differentiate
        spacing : array or None
            the spacing the gradient components are divided by
        direction : array or None
            the direction information handed to `reorient_vector_field`

        Returns
        -------
        grad_field : array, shape (R, C, 2) or (S, R, C, 3)
            the gradient field, one component per image axis
        """
        grad_field = np.empty(shape=image.shape + (self.dim,),
                              dtype=floating)
        for axis, partial in enumerate(gradient(image)):
            grad_field[..., axis] = partial

        if spacing is not None:
            grad_field /= spacing
        if direction is not None:
            self.reorient_vector_field(grad_field, direction)
        return grad_field

    def initialize_iteration(self):
        r"""Prepares the metric to compute one displacement field iteration.

        Pre-computes the transfer functions (hidden random variables) and
        variances of the estimators. Also pre-computes the gradient of both
        input images. Note that once the images are transformed to the opposite
        modality, the gradient of the transformed images can be used with the
        gradient of the corresponding modality in the same fashion as
        diff-demons does for mono-modality images. If the flag
        self.use_double_gradient is True, the (voxel-space) gradients of the
        expected images are added to the image gradients.

        The static and moving image masks must have been set (see
        `use_static_image_dynamics` and `use_moving_image_dynamics`) before
        calling this method, since their product is the sampling mask.
        """
        sampling_mask = self.static_image_mask * self.moving_image_mask
        self.sampling_mask = sampling_mask

        # Statistics of the moving image at each level of the static image
        (self.staticq_levels,
         self.staticq_means,
         self.staticq_variances,
         self.staticq_sigma_sq_field,
         self.staticq_means_field) = self._class_statistics(
            self.static_image, self.moving_image, sampling_mask)

        # Convert both gradient fields from voxel to physical space
        self.gradient_moving = self._physical_gradient(
            self.moving_image, self.moving_spacing, self.moving_direction)
        self.gradient_static = self._physical_gradient(
            self.static_image, self.static_spacing, self.static_direction)

        # Statistics of the static image at each level of the moving image
        (self.movingq_levels,
         self.movingq_means,
         self.movingq_variances,
         self.movingq_sigma_sq_field,
         self.movingq_means_field) = self._class_statistics(
            self.moving_image, self.static_image, sampling_mask)

        if self.use_double_gradient:
            for i, grad in enumerate(gradient(self.staticq_means_field)):
                self.gradient_moving[..., i] += grad

            for i, grad in enumerate(gradient(self.movingq_means_field)):
                self.gradient_static[..., i] += grad

    def free_iteration(self):
        r"""
        Frees the resources allocated during initialization

        Deletes the sampling mask, the quantization levels, the per-voxel
        mean / variance fields and the gradient fields.
        """
        del self.sampling_mask
        del self.staticq_levels
        del self.movingq_levels
        del self.staticq_sigma_sq_field
        del self.staticq_means_field
        del self.movingq_sigma_sq_field
        del self.movingq_means_field
        del self.gradient_moving
        del self.gradient_static

    def compute_forward(self):
        r"""Computes one step bringing the moving image towards the static.

        Computes the forward update field to register the moving image towards
        the static image in a gradient-based optimization algorithm, using
        the step type selected at construction.
        """
        return self.compute_step(True)

    def compute_backward(self):
        r"""Computes one step bringing the static image towards the moving.

        Computes the update displacement field to be used for registration of
        the static image towards the moving image, using the step type
        selected at construction.
        """
        return self.compute_step(False)

    def compute_gauss_newton_step(self, forward_step=True):
        r"""Computes the Gauss-Newton energy minimization step

        Computes the Newton step to minimize this energy, i.e., minimizes the
        linearized energy function with respect to the
        regularized displacement field (this step does not require
        post-smoothing, as opposed to the demons step, which does not include
        regularization). To accelerate convergence we use the multi-grid
        Gauss-Seidel algorithm proposed by Bruhn and Weickert et al [Bruhn05]

        The value returned by the multi-grid solver is stored as the current
        energy of the metric.

        Parameters
        ----------
        forward_step : boolean
            if True, computes the Newton step in the forward direction
            (warping the moving towards the static image). If False,
            computes the backward step (warping the static image to the
            moving image)

        Returns
        -------
        displacement : array, shape (R, C, 2) or (S, R, C, 3)
            the Newton step

        References
        ----------
        [Bruhn05] Andres Bruhn and Joachim Weickert, "Towards ultimate motion
                  estimation: combining highest accuracy with real-time
                  performance", 10th IEEE International Conference on Computer
                  Vision, 2005. ICCV 2005.
        """
        reference_shape = self.static_image.shape

        if forward_step:
            grad_field = self.gradient_static
            delta = self.staticq_means_field - self.moving_image
            sigma_sq_field = self.staticq_sigma_sq_field
        else:
            grad_field = self.gradient_moving
            delta = self.movingq_means_field - self.static_image
            sigma_sq_field = self.movingq_sigma_sq_field

        # Output buffer handed to the solver as its last argument
        displacement = np.zeros(shape=reference_shape + (self.dim,),
                                dtype=floating)

        v_cycle = v_cycle_2d if self.dim == 2 else v_cycle_3d
        self.energy = v_cycle(self.levels_below,
                              self.inner_iter, delta,
                              sigma_sq_field,
                              grad_field,
                              None,
                              self.smooth,
                              displacement)
        return displacement

    def compute_demons_step(self, forward_step=True):
        r"""Demons step for EM metric

        The raw step is smoothed, component by component, with a Gaussian
        kernel of standard deviation `self.smooth`. The energy returned by
        the step function is stored as the current energy of the metric.

        Parameters
        ----------
        forward_step : boolean
            if True, computes the Demons step in the forward direction
            (warping the moving towards the static image). If False,
            computes the backward step (warping the static image to the
            moving image)

        Returns
        -------
        displacement : array, shape (R, C, 2) or (S, R, C, 3)
            the Demons step
        """
        # Mean squared spacing of the static grid
        sigma_reg_2 = np.sum(self.static_spacing**2)/self.dim

        if forward_step:
            grad_field = self.gradient_static
            delta_field = self.static_image - self.movingq_means_field
            sigma_sq_field = self.movingq_sigma_sq_field
        else:
            grad_field = self.gradient_moving
            delta_field = self.moving_image - self.staticq_means_field
            sigma_sq_field = self.staticq_sigma_sq_field

        if self.dim == 2:
            demons_step = em.compute_em_demons_step_2d
        else:
            demons_step = em.compute_em_demons_step_3d
        step, self.energy = demons_step(delta_field,
                                        sigma_sq_field,
                                        grad_field,
                                        sigma_reg_2,
                                        None)

        for i in range(self.dim):
            # ndimage.gaussian_filter is the public entry point; the
            # ndimage.filters namespace is deprecated in SciPy
            step[..., i] = ndimage.gaussian_filter(step[..., i],
                                                   self.smooth)
        return step

    def get_energy(self):
        r"""The numerical value assigned by this metric to the current image pair

        Returns the EM (data term) energy stored by the latest computed step
        (forward or backward). The attribute does not exist before the first
        step is computed.
        """
        return self.energy

    def use_static_image_dynamics(self, original_static_image, transformation):
        r"""This is called by the optimizer just after setting the static image.

        EMMetric takes advantage of the image dynamics by computing the
        current static image mask from the original_static_image mask (warped
        by nearest neighbor interpolation). The mask marks the voxels of the
        original image with positive intensity.

        Parameters
        ----------
        original_static_image : array, shape (R, C) or (S, R, C)
            the original static image from which the current static image was
            generated, the current static image is the one that was provided
            via 'set_static_image(...)', which may not be the same as the
            original static image but a warped version of it (even the static
            image changes during Symmetric Normalization, not only the moving
            one).
        transformation : DiffeomorphicMap object or None
            the transformation that was applied to the original_static_image
            to generate the current static image. If None, the mask of the
            original image is used without warping
        """
        self.static_image_mask = (original_static_image > 0).astype(np.int32)
        if transformation is None:
            return
        shape = np.array(self.static_image.shape, dtype=np.int32)
        affine = self.static_affine
        self.static_image_mask = transformation.transform(
            self.static_image_mask, 'nearest', None, shape, affine)

    def use_moving_image_dynamics(self, original_moving_image, transformation):
        r"""This is called by the optimizer just after setting the moving image.

        EMMetric takes advantage of the image dynamics by computing the
        current moving image mask from the original_moving_image mask (warped
        by nearest neighbor interpolation). The mask marks the voxels of the
        original image with positive intensity.

        Parameters
        ----------
        original_moving_image : array, shape (R, C) or (S, R, C)
            the original moving image from which the current moving image was
            generated, the current moving image is the one that was provided
            via 'set_moving_image(...)', which may not be the same as the
            original moving image but a warped version of it.
        transformation : DiffeomorphicMap object or None
            the transformation that was applied to the original_moving_image
            to generate the current moving image. If None, the mask of the
            original image is used without warping
        """
        self.moving_image_mask = (original_moving_image > 0).astype(np.int32)
        if transformation is None:
            return
        shape = np.array(self.moving_image.shape, dtype=np.int32)
        affine = self.moving_affine
        self.moving_image_mask = transformation.transform(
            self.moving_image_mask, 'nearest', None, shape, affine)
680
class SSDMetric(SimilarityMetric):

    def __init__(self, dim, smooth=4, inner_iter=10, step_type='demons'):
        r"""Sum of Squared Differences (SSD) Metric

        Similarity metric for (mono-modal) nonlinear image registration defined
        by the sum of squared differences (SSD)

        Parameters
        ----------
        dim : int (either 2 or 3)
            the dimension of the image domain
        smooth : float
            smoothness parameter, the larger the value the smoother the
            deformation field. It is used as the regularization weight of
            the Gauss-Newton step and as the sigma of the Gaussian filter
            applied to the Demons step
        inner_iter : int
            number of iterations to be performed at each level of the multi-
            resolution Gauss-Seidel optimization algorithm (this is not the
            number of steps per Gaussian Pyramid level, that parameter must
            be set for the optimizer, not the metric)
        step_type : string
            the displacement field step to be computed when 'compute_forward'
            and 'compute_backward' are called. Either 'demons' or
            'gauss_newton'

        Raises
        ------
        ValueError
            if `dim` is neither 2 nor 3, or if `step_type` is neither
            'demons' nor 'gauss_newton'
        """
        super().__init__(dim)
        self.smooth = smooth
        self.inner_iter = inner_iter
        self.step_type = step_type
        self.levels_below = 0
        self._connect_functions()

    def _connect_functions(self):
        r"""Assign the methods to be called according to the image dimension

        Assigns the appropriate functions to be called for vector field
        reorientation and displacement field steps according to the
        dimension of the input images and the selected type of step (either
        Demons or Gauss Newton)

        Raises
        ------
        ValueError
            if the dimension or the step type is not supported
        """
        if self.dim == 2:
            self.reorient_vector_field = vfu.reorient_vector_field_2d
        elif self.dim == 3:
            self.reorient_vector_field = vfu.reorient_vector_field_3d
        else:
            raise ValueError('SSD Metric not defined for dim. %d' % (self.dim))

        if self.step_type == 'gauss_newton':
            self.compute_step = self.compute_gauss_newton_step
        elif self.step_type == 'demons':
            self.compute_step = self.compute_demons_step
        else:
            raise ValueError('Opt. step %s not defined' % (self.step_type))

    def _gradient_in_physical_space(self, image, spacing, direction):
        r"""Gradient of an image, converted from voxel to physical space

        Parameters
        ----------
        image : array, shape (R, C) or (S, R, C)
            the image whose gradient is computed
        spacing : array or None
            voxel spacing the gradient is divided by; if None, the gradient
            is left unscaled
        direction : array or None
            direction matrix passed to the reorientation function; if None,
            the gradient is not reoriented

        Returns
        -------
        grad_field : array, shape image.shape + (dim,)
            the gradient field, one component per image axis
        """
        grad_field = np.empty(shape=image.shape + (self.dim,),
                              dtype=floating)
        for axis, partial in enumerate(gradient(image)):
            grad_field[..., axis] = partial

        if spacing is not None:
            grad_field /= spacing
        if direction is not None:
            # The result is discarded, so the field is presumably reoriented
            # in place by the vector_fields routine
            self.reorient_vector_field(grad_field, direction)
        return grad_field

    def initialize_iteration(self):
        r"""Prepares the metric to compute one displacement field iteration.

        Pre-computes the gradient of the input images to be used in the
        computation of the forward and backward steps. The results are stored
        in `self.gradient_moving` and `self.gradient_static`.
        """
        self.gradient_moving = self._gradient_in_physical_space(
            self.moving_image, self.moving_spacing, self.moving_direction)
        self.gradient_static = self._gradient_in_physical_space(
            self.static_image, self.static_spacing, self.static_direction)

    def compute_forward(self):
        r"""Computes one step bringing the reference image towards the static.

        Computes the update displacement field to be used for registration of
        the moving image towards the static image
        """
        return self.compute_step(True)

    def compute_backward(self):
        r"""Computes one step bringing the static image towards the moving.

        Computes the updated displacement field to be used for registration of
        the static image towards the moving image
        """
        return self.compute_step(False)

    def compute_gauss_newton_step(self, forward_step=True):
        r"""Computes the Gauss-Newton energy minimization step

        Minimizes the linearized energy function (Newton step) defined by the
        sum of squared differences of corresponding pixels of the input images
        with respect to the displacement field. The energy returned by the
        multi-resolution solver is stored in `self.energy`.

        Parameters
        ----------
        forward_step : boolean
            if True, computes the Newton step in the forward direction
            (warping the moving towards the static image). If False,
            computes the backward step (warping the static image to the
            moving image)

        Returns
        -------
        displacement : array, shape = static_image.shape + (dim,)
            if forward_step==True, the forward SSD Gauss-Newton step,
            else, the backward step
        """
        if forward_step:
            grad_field = self.gradient_static
            delta_field = self.static_image - self.moving_image
        else:
            grad_field = self.gradient_moving
            delta_field = self.moving_image - self.static_image

        # The step is always allocated on the static image grid, for both
        # the forward and the backward direction
        displacement = np.zeros(shape=self.static_image.shape + (self.dim,),
                                dtype=floating)

        v_cycle = v_cycle_2d if self.dim == 2 else v_cycle_3d
        self.energy = v_cycle(self.levels_below, self.inner_iter,
                              delta_field, None, grad_field, None,
                              self.smooth, displacement)
        return displacement

    def compute_demons_step(self, forward_step=True):
        r"""Demons step for SSD metric

        Computes the demons step proposed by Vercauteren et al.[Vercauteren09]
        for the SSD metric. Each component of the step is smoothed with a
        Gaussian filter of sigma `self.smooth`, and the energy returned by
        the step routine is stored in `self.energy`.

        Parameters
        ----------
        forward_step : boolean
            if True, computes the Demons step in the forward direction
            (warping the moving towards the static image). If False,
            computes the backward step (warping the static image to the
            moving image)

        Returns
        -------
        displacement : array, shape (R, C, 2) or (S, R, C, 3)
            the Demons step

        References
        ----------
        [Vercauteren09] Tom Vercauteren, Xavier Pennec, Aymeric Perchant,
                        Nicholas Ayache, "Diffeomorphic Demons: Efficient
                        Non-parametric Image Registration", Neuroimage 2009
        """
        if self.static_spacing is None:
            # Unit spacing is assumed, consistently with
            # initialize_iteration, which leaves the gradient unscaled when
            # no spacing is available (mean of squared unit spacings is 1)
            sigma_reg_2 = 1.0
        else:
            sigma_reg_2 = np.sum(self.static_spacing**2)/self.dim

        if forward_step:
            grad_field = self.gradient_static
            delta_field = self.static_image - self.moving_image
        else:
            grad_field = self.gradient_moving
            delta_field = self.moving_image - self.static_image

        if self.dim == 2:
            step, self.energy = ssd.compute_ssd_demons_step_2d(delta_field,
                                                               grad_field,
                                                               sigma_reg_2,
                                                               None)
        else:
            step, self.energy = ssd.compute_ssd_demons_step_3d(delta_field,
                                                               grad_field,
                                                               sigma_reg_2,
                                                               None)
        # `ndimage.gaussian_filter` is the supported entry point; the
        # `ndimage.filters` namespace is deprecated in recent SciPy releases
        for i in range(self.dim):
            step[..., i] = ndimage.gaussian_filter(step[..., i], self.smooth)
        return step

    def get_energy(self):
        r"""The numerical value assigned by this metric to the current image pair

        Returns the Sum of Squared Differences (data term) energy stored by
        the most recently computed step (`self.energy` is only set by
        'compute_gauss_newton_step' and 'compute_demons_step')
        """
        return self.energy

    def free_iteration(self):
        r"""
        Nothing to free for the SSD metric
        """
        pass
886
887
def v_cycle_2d(n, k, delta_field, sigma_sq_field, gradient_field, target,
               lambda_param, displacement, depth=0):
    r"""Multi-resolution Gauss-Seidel solver using V-type cycles

    Multi-resolution Gauss-Seidel solver: solves the Gauss-Newton linear system
    by first filtering (GS-iterate) the current level, then solves for the
    residual at a coarser resolution and finally refines the solution at the
    current resolution. This scheme corresponds to the V-cycle proposed by
    Bruhn and Weickert[Bruhn05].

    Parameters
    ----------
    n : int
        number of levels of the multi-resolution algorithm (it will be called
        recursively until level n == 0). Values below zero are treated as
        zero, so that the recursion always terminates
    k : int
        the number of iterations at each multi-resolution level
    delta_field : array, shape (R, C)
        the difference between the static and moving image (the 'derivative
        w.r.t. time' in the optical flow model)
    sigma_sq_field : array, shape (R, C), or None
        the variance of the gray level value at each voxel, according to the
        EM model (for SSD, it is 1 for all voxels). Inf and 0 values
        are processed specially to support infinite and zero variance.
        If None, it is forwarded as None to the solver routines and to the
        coarser levels
    gradient_field : array, shape (R, C, 2)
        the gradient of the moving image
    target : array, shape (R, C, 2), or None
        right-hand side of the linear system to be solved in the Weickert's
        multi-resolution algorithm. At the coarser levels it is the
        down-sampled residual of the level above
    lambda_param : float
        smoothness parameter, the larger its value the smoother the
        displacement field
    displacement : array, shape (R, C, 2)
        the displacement field to start the optimization from. The
        coarse-grid correction is added to this array in place
    depth : int, optional
        recursion depth of the current call (0 for the finest level). It is
        only forwarded, incremented by one, to the recursive call

    Returns
    -------
    energy : the energy of the EM (or SSD if sigmafield[...]==1) metric at this
        iteration, as computed from `delta_field` at the current level

    References
    ----------
    [Bruhn05] Andres Bruhn and Joachim Weickert, "Towards ultimate motion
              estimation: combining the highest accuracy with real-time
              performance", 10th IEEE International Conference on Computer
              Vision, 2005. ICCV 2005.
    """
    def relax():
        # k Gauss-Seidel sweeps at the current resolution
        for _ in range(k):
            ssd.iterate_residual_displacement_field_ssd_2d(delta_field,
                                                           sigma_sq_field,
                                                           gradient_field,
                                                           target,
                                                           lambda_param,
                                                           displacement)

    # pre-smoothing
    relax()
    if n <= 0:
        # Coarsest level reached. Using `<=` prevents unbounded recursion
        # when a negative number of levels is given
        return ssd.compute_energy_ssd_2d(delta_field)

    # solve at coarser grid
    residual = ssd.compute_residual_displacement_field_ssd_2d(delta_field,
                                                              sigma_sq_field,
                                                              gradient_field,
                                                              target,
                                                              lambda_param,
                                                              displacement,
                                                              None)
    sub_residual = np.array(vfu.downsample_displacement_field_2d(residual))
    del residual
    subsigma_sq_field = None
    if sigma_sq_field is not None:
        subsigma_sq_field = vfu.downsample_scalar_field_2d(sigma_sq_field)
    subdelta_field = vfu.downsample_scalar_field_2d(delta_field)
    subgradient_field = np.array(
        vfu.downsample_displacement_field_2d(gradient_field))

    shape = np.array(displacement.shape).astype(np.int32)
    # Each spatial axis is halved (rounding up) at the coarser level
    half_shape = ((shape[0] + 1) // 2, (shape[1] + 1) // 2, 2)
    sub_displacement = np.zeros(shape=half_shape,
                                dtype=floating)
    sublambda_param = lambda_param*0.25
    # The energy of the coarser level is not needed: it is recomputed at
    # this resolution after post-smoothing
    v_cycle_2d(n-1, k, subdelta_field, subsigma_sq_field, subgradient_field,
               sub_residual, sublambda_param, sub_displacement, depth+1)
    # Release the coarse-level temporaries before resampling
    del subdelta_field
    del subsigma_sq_field
    del subgradient_field
    del sub_residual
    displacement += vfu.resample_displacement_field_2d(sub_displacement,
                                                       np.array([0.5, 0.5]),
                                                       shape)
    del sub_displacement

    # post-smoothing
    relax()
    return ssd.compute_energy_ssd_2d(delta_field)
989
990
def v_cycle_3d(n, k, delta_field, sigma_sq_field, gradient_field, target,
               lambda_param, displacement, depth=0):
    r"""Multi-resolution Gauss-Seidel solver using V-type cycles

    Multi-resolution Gauss-Seidel solver: solves the linear system by first
    filtering (GS-iterate) the current level, then solves for the residual
    at a coarser resolution and finally refines the solution at the current
    resolution. This scheme corresponds to the V-cycle proposed by Bruhn and
    Weickert[Bruhn05].

    Parameters
    ----------
    n : int
        number of levels of the multi-resolution algorithm (it will be called
        recursively until level n == 0). Values below zero are treated as
        zero, so that the recursion always terminates
    k : int
        the number of iterations at each multi-resolution level
    delta_field : array, shape (S, R, C)
        the difference between the static and moving image (the 'derivative
        w.r.t. time' in the optical flow model)
    sigma_sq_field : array, shape (S, R, C), or None
        the variance of the gray level value at each voxel, according to the
        EM model (for SSD, it is 1 for all voxels). Inf and 0 values
        are processed specially to support infinite and zero variance.
        If None, it is forwarded as None to the solver routines and to the
        coarser levels
    gradient_field : array, shape (S, R, C, 3)
        the gradient of the moving image
    target : array, shape (S, R, C, 3), or None
        right-hand side of the linear system to be solved in the Weickert's
        multi-resolution algorithm. At the coarser levels it is the
        down-sampled residual of the level above
    lambda_param : float
        smoothness parameter, the larger its value the smoother the
        displacement field
    displacement : array, shape (S, R, C, 3)
        the displacement field to start the optimization from. The
        coarse-grid correction is added to this array in place
    depth : int, optional
        recursion depth of the current call (0 for the finest level). It is
        only forwarded, incremented by one, to the recursive call

    Returns
    -------
    energy : the energy of the EM (or SSD if sigmafield[...]==1) metric at this
        iteration, as computed from `delta_field` at the current level

    References
    ----------
    [Bruhn05] Andres Bruhn and Joachim Weickert, "Towards ultimate motion
              estimation: combining the highest accuracy with real-time
              performance", 10th IEEE International Conference on Computer
              Vision, 2005. ICCV 2005.
    """
    def relax():
        # k Gauss-Seidel sweeps at the current resolution
        for _ in range(k):
            ssd.iterate_residual_displacement_field_ssd_3d(delta_field,
                                                           sigma_sq_field,
                                                           gradient_field,
                                                           target,
                                                           lambda_param,
                                                           displacement)

    # pre-smoothing
    relax()
    if n <= 0:
        # Coarsest level reached. Using `<=` prevents unbounded recursion
        # when a negative number of levels is given
        return ssd.compute_energy_ssd_3d(delta_field)

    # solve at coarser grid
    residual = ssd.compute_residual_displacement_field_ssd_3d(delta_field,
                                                              sigma_sq_field,
                                                              gradient_field,
                                                              target,
                                                              lambda_param,
                                                              displacement,
                                                              None)
    sub_residual = np.array(vfu.downsample_displacement_field_3d(residual))
    del residual
    subsigma_sq_field = None
    if sigma_sq_field is not None:
        subsigma_sq_field = vfu.downsample_scalar_field_3d(sigma_sq_field)
    subdelta_field = vfu.downsample_scalar_field_3d(delta_field)
    subgradient_field = np.array(
        vfu.downsample_displacement_field_3d(gradient_field))

    shape = np.array(displacement.shape).astype(np.int32)
    # Each spatial axis is halved (rounding up) at the coarser level
    half_shape = ((shape[0]+1)//2, (shape[1]+1)//2, (shape[2]+1)//2, 3)
    sub_displacement = np.zeros(shape=half_shape, dtype=floating)
    sublambda_param = lambda_param*0.25
    # The energy of the coarser level is not needed: it is recomputed at
    # this resolution after post-smoothing
    v_cycle_3d(n-1, k, subdelta_field, subsigma_sq_field, subgradient_field,
               sub_residual, sublambda_param, sub_displacement, depth+1)
    # Release the coarse-level temporaries before resampling
    del subdelta_field
    del subsigma_sq_field
    del subgradient_field
    del sub_residual
    displacement += vfu.resample_displacement_field_3d(sub_displacement,
                                                       0.5 * np.ones(3),
                                                       shape)
    del sub_displacement

    # post-smoothing
    relax()
    return ssd.compute_energy_ssd_3d(delta_field)
1087