1import numpy as np
2import scipy as sp
3import nibabel as nib
4import numpy.linalg as npl
5from numpy.testing import (assert_array_equal,
6                           assert_array_almost_equal,
7                           assert_almost_equal,
8                           assert_equal,
9                           assert_raises)
10from dipy.core import geometry as geometry
11from dipy.viz import regtools as rt
12from dipy.align import floating
13from dipy.align import vector_fields as vf
14from dipy.align import imaffine
15from dipy.align.imaffine import AffineInversionError, AffineInvalidValuesError, \
16    AffineMap, _number_dim_affine_matrix
17from dipy.align.transforms import (Transform,
18                                   regtransforms)
19from dipy.align.tests.test_parzenhist import (setup_random_transform,
20                                              sample_domain_regular)
21
# Test configuration, keyed by (transform name, dimension). Each value is a
# 3-tuple holding, in order:
#   0. a transform factor, indicating how large the true transform between
#      the static and moving images will be;
#   1. the sampling scheme handed to `MutualInformationMetric` for estimating
#      the joint PDFs: either a sampling proportion (a float; the values used
#      here lie in (0, 1]) or None for dense sampling;
#   2. a parameter vector away from the identity (to avoid starting from the
#      identity); `test_mi_gradient` evaluates the metric gradient at this
#      point.
factors = {('TRANSLATION', 2): (2.0, 0.35, np.array([2.3, 4.5])),
           ('ROTATION', 2): (0.1, None, np.array([0.1])),
           ('RIGID', 2): (0.1, .50, np.array([0.12, 1.8, 2.7])),
           ('SCALING', 2): (0.01, None, np.array([1.05])),
           ('AFFINE', 2): (0.1, .50, np.array([0.99, -0.05, 1.3, 0.05, 0.99,
                                               2.5])),
           ('TRANSLATION', 3): (2.0, None, np.array([2.3, 4.5, 1.7])),
           ('ROTATION', 3): (0.1, 1.0, np.array([0.1, 0.15, -0.11])),
           ('RIGID', 3): (0.1, None, np.array([0.1, 0.15, -0.11, 2.3, 4.5,
                                               1.7])),
           ('SCALING', 3): (0.1, .35, np.array([0.95])),
           ('AFFINE', 3): (0.1, None, np.array([0.99, -0.05, 0.03, 1.3,
                                                0.05, 0.99, -0.10, 2.5,
                                                -0.07, 0.10, 0.99, -1.4]))}
42
43
def test_transform_centers_of_mass_3d():
    """Check `imaffine.transform_centers_of_mass` on a pair of spheres.

    The static grid-to-world matrix is built as
    ``inv(T).dot(S.dot(R.dot(T)))`` for several rotation angles and isotropic
    scales, and the moving grid-to-world matrix is its inverse.  The
    ``affine`` of the returned map must equal the identity plus the
    translation ``c_moving - c_static`` between the world-space centers of
    mass of the two images.
    """
    np.random.seed(1246592)
    shape = (64, 64, 64)
    half = tuple(extent // 2 for extent in shape)

    # Moving image: a sphere of radius 8 pasted into the low-index corner
    # block of the grid, so its center of mass is (16, 16, 16) in image
    # coordinates
    small_sphere = vf.create_sphere(half[0], half[1], half[2], 8)
    moving = np.zeros(shape)
    moving[:half[0], :half[1], :half[2]] = small_sphere[...]

    # Static image: a sphere of radius 16 whose center of mass is
    # (32, 32, 32) in image coordinates
    static = vf.create_sphere(shape[0], shape[1], shape[2], 16)

    # Homogeneous image coordinates of both centers of mass
    com_static = (32, 32, 32, 1)
    com_moving = (16, 16, 16, 1)

    # Ingredients of the arbitrary image-to-space transforms
    axis = np.array([.5, 2.0, 1.5])
    t = 0.15  # translation, as a fraction of the grid shape
    trans = np.eye(4)
    trans[:3, 3] = [-t * extent for extent in shape]
    trans_inv = npl.inv(trans)

    angles = [-1 * np.pi / 6.0, 0.0, np.pi / 5.0]
    zooms = [0.83, 1.3, 2.07]
    for rotation_angle in angles:
        for scale_factor in zooms:
            rot = np.eye(4)
            rot[:3, :3] = geometry.rodrigues_axis_rotation(axis,
                                                           rotation_angle)
            scale = np.diag([scale_factor, scale_factor, scale_factor, 1.0])

            static_grid2world = trans_inv.dot(scale.dot(rot.dot(trans)))
            moving_grid2world = npl.inv(static_grid2world)

            # Expected result: a pure translation between the two
            # world-space centers of mass
            world_static = static_grid2world.dot(com_static)[:3]
            world_moving = moving_grid2world.dot(com_moving)[:3]
            expected = np.eye(4)
            expected[:3, 3] = world_moving - world_static

            # Implementation under test
            actual = imaffine.transform_centers_of_mass(
                static, static_grid2world, moving, moving_grid2world)
            assert_array_almost_equal(actual.affine, expected)
92
93
def test_transform_geometric_centers_3d():
    """Check `imaffine.transform_geometric_centers` over many grid setups.

    For every combination of rotation angle, isotropic scale, moving shape
    and static shape, the static grid-to-world matrix is built as
    ``inv(T).dot(S.dot(R.dot(T)))`` and the moving grid-to-world matrix is
    its inverse.  The ``affine`` of the returned map must equal the identity
    plus the translation between the world-space images of the two grid
    centers (half the shape along each axis).
    """
    rot_axis = np.array([.5, 2.0, 1.5])
    t = 0.15  # translation, as a fraction of the static shape
    angles = [-1 * np.pi / 6.0, 0.0, np.pi / 5.0]
    zooms = [0.83, 1.3, 2.07]
    grid_shapes = [(256, 256, 128), (255, 255, 127), (64, 127, 142)]

    for theta in angles:
        for s in zooms:
            for shape_moving in grid_shapes:
                for shape_static in grid_shapes:
                    # Voxel values are left uninitialized (presumably only
                    # the shapes are used by the function under test)
                    moving = np.empty(shape_moving)
                    static = np.empty(shape_static)

                    # Arbitrary image-to-space transforms
                    trans = np.eye(4)
                    trans[:3, 3] = [-t * extent for extent in shape_static]
                    rot = np.eye(4)
                    rot[:3, :3] = geometry.rodrigues_axis_rotation(rot_axis,
                                                                   theta)
                    scale = np.diag([s, s, s, 1.0])

                    static_grid2world = npl.inv(trans).dot(
                        scale.dot(rot.dot(trans)))
                    moving_grid2world = npl.inv(static_grid2world)

                    # Homogeneous image coordinates of both grid centers
                    center_static = np.append(
                        0.5 * np.array(shape_static, dtype=np.float64), 1.0)
                    center_moving = np.append(
                        0.5 * np.array(shape_moving, dtype=np.float64), 1.0)

                    # Expected translation between the world-space centers
                    world_static = static_grid2world.dot(center_static)[:3]
                    world_moving = moving_grid2world.dot(center_moving)[:3]
                    expected = np.eye(4)
                    expected[:3, 3] = world_moving - world_static

                    # Implementation under test
                    actual = imaffine.transform_geometric_centers(
                        static, static_grid2world, moving, moving_grid2world)
                    assert_array_almost_equal(actual.affine, expected)
138
139
def test_transform_origins_3d():
    """Check `imaffine.transform_origins` over many grid setups.

    For every combination of rotation angle, isotropic scale, moving shape
    and static shape, the static grid-to-world matrix is built as
    ``inv(T).dot(S.dot(R.dot(T)))`` and the moving grid-to-world matrix is
    its inverse.  The ``affine`` of the returned map must equal the identity
    plus the difference between the translation columns of the moving and
    static grid-to-world matrices.
    """
    rot_axis = np.array([.5, 2.0, 1.5])
    t = 0.15  # translation, as a fraction of the static shape
    angles = [-1 * np.pi / 6.0, 0.0, np.pi / 5.0]
    zooms = [0.83, 1.3, 2.07]
    grid_shapes = [(256, 256, 128), (255, 255, 127), (64, 127, 142)]

    for theta in angles:
        for s in zooms:
            for shape_moving in grid_shapes:
                for shape_static in grid_shapes:
                    # Voxel values are left uninitialized (presumably only
                    # the grids' geometry is used by the function under test)
                    moving = np.empty(shape_moving)
                    static = np.empty(shape_static)

                    # Arbitrary image-to-space transforms
                    trans = np.eye(4)
                    trans[:3, 3] = [-t * extent for extent in shape_static]
                    rot = np.eye(4)
                    rot[:3, :3] = geometry.rodrigues_axis_rotation(rot_axis,
                                                                   theta)
                    scale = np.diag([s, s, s, 1.0])

                    static_grid2world = npl.inv(trans).dot(
                        scale.dot(rot.dot(trans)))
                    moving_grid2world = npl.inv(static_grid2world)

                    # Expected translation: difference of the translation
                    # columns of the two grid-to-world matrices
                    origin_static = static_grid2world[:3, 3]
                    origin_moving = moving_grid2world[:3, 3]
                    expected = np.eye(4)
                    expected[:3, 3] = origin_moving - origin_static

                    # Implementation under test
                    actual = imaffine.transform_origins(
                        static, static_grid2world, moving, moving_grid2world)
                    assert_array_almost_equal(actual.affine, expected)
182
183
def test_affreg_all_transforms():
    """Run affine registration with typical settings for every transform.

    For each entry of the module-level `factors` table a static/moving pair
    is generated with `setup_random_transform`, registered with a mutual
    information metric, and the sum of absolute differences between the
    static image and the warped moving image must drop by more than 90%.
    Finally, constructing an `AffineRegistration` with an empty list of
    level iterations must raise ValueError.
    """
    # Dictionary entries are visited in sorted order so that the sequence of
    # transforms is the same on every platform. This test draws no random
    # samples inside the loop at the moment, but a fixed order protects
    # against non-determinism if that ever changes.
    for ttype in sorted(factors):
        nslices = 1 if ttype[1] == 2 else 45
        factor, sampling_pc = factors[ttype][:2]
        trans = regtransforms[ttype]
        (static, moving, static_g2w, moving_g2w,
         smask, mmask, T) = setup_random_transform(trans, factor, nslices,
                                                   1.0)

        # Sum of absolute differences before registration
        sad_before = np.abs(static - moving).sum()

        metric = imaffine.MutualInformationMetric(32, sampling_pc)
        affreg = imaffine.AffineRegistration(
            metric, [1000, 100, 50], [3, 1, 0], [4, 2, 1], 'L-BFGS-B', None,
            options=None)
        x0 = trans.get_identity_parameters()
        affine_map = affreg.optimize(static, moving, trans, x0,
                                     static_g2w, moving_g2w)
        warped = affine_map.transform(moving)

        # Sum of absolute differences after registration
        sad_after = np.abs(static - warped).sum()
        reduction = 1 - sad_after / sad_before
        print("%s>>%f" % (ttype, reduction))
        assert reduction > 0.9

    # Verify that exception is raised if level_iters is empty
    metric = imaffine.MutualInformationMetric(32)
    assert_raises(ValueError, imaffine.AffineRegistration, metric, [])
231
232
def test_affreg_defaults():
    """Run affine registration leaving every optional argument at None.

    A single arbitrary transform (2D translation) is used, since all
    transforms are already exercised by `test_affreg_all_transforms`.  The
    registration is repeated for each supported `starting_affine` option and
    for ``ss_sigma_factor`` in ``[1.0, None]``; both the forward and the
    inverse warp must reduce the sum of absolute differences by more than
    90%.
    """
    dim = 2
    ttype = ('TRANSLATION', dim)
    aff_options = ['mass', 'voxel-origin', 'centers', None, np.eye(dim + 1)]

    for starting_affine in aff_options:
        nslices = 1 if dim == 2 else 45
        factor = factors[ttype][0]
        transform = regtransforms[ttype]
        # The generated grid-to-world matrices are deliberately discarded:
        # None is passed to `optimize` instead
        generated = setup_random_transform(transform, factor, nslices, 1.0)
        static, moving = generated[0], generated[1]

        # Sum of absolute differences before registration
        sad_before = np.abs(static - moving).sum()

        for ss_sigma_factor in [1.0, None]:
            # metric, level_iters, sigmas and scale factors are all None
            affreg = imaffine.AffineRegistration(
                None, None, None, None, 'L-BFGS-B', ss_sigma_factor,
                options=None)
            # x0 and both grid-to-world matrices are None as well
            affine_map = affreg.optimize(static, moving, transform, None,
                                         None, None, starting_affine)

            # Forward direction: moving warped towards static
            warped = affine_map.transform(moving)
            sad_after = np.abs(static - warped).sum()
            reduction = 1 - sad_after / sad_before
            print("%s>>%f" % (ttype, reduction))
            assert reduction > 0.9

            # Inverse direction: static warped towards moving
            warped_inv = affine_map.transform_inverse(static)
            sad_after = np.abs(moving - warped_inv).sum()
            reduction = 1 - sad_after / sad_before
            print("%s>>%f" % (ttype, reduction))
            assert reduction > 0.9
285
286
def test_mi_gradient():
    """Compare the mutual information gradient with finite differences.

    For every entry of the module-level `factors` table, the gradient
    returned by `MutualInformationMetric.gradient` at the configured
    parameter vector is compared against a forward finite-difference
    estimate built from `MutualInformationMetric.distance`.  The cosine of
    the angle between both vectors must be at least 0.99.
    """
    np.random.seed(2022966)
    h = 1e-5  # finite-difference step

    # Dictionary entries are visited in sorted order so that the random
    # samples drawn below with `np.random.randn` are the same on every
    # platform; otherwise fixing the seed would not be enough to make the
    # test deterministic.
    for ttype in sorted(factors):
        transform = regtransforms[ttype]
        dim = ttype[1]
        nslices = 1 if dim == 2 else 45
        factor, sampling_proportion, theta = factors[ttype]

        # Start from a small random rotation
        rotation = regtransforms[('ROTATION', dim)]
        n_rot = rotation.get_number_of_parameters()
        starting_affine = rotation.param_to_matrix(
            0.25 * np.random.randn(n_rot))

        # Pair of images related to each other by a known transform
        (static, moving, static_g2w, moving_g2w,
         smask, mmask, M) = setup_random_transform(transform, factor,
                                                   nslices, 2.0)

        # Prepare a MutualInformationMetric instance
        mi_metric = imaffine.MutualInformationMetric(32, sampling_proportion)
        mi_metric.setup(transform, static, moving,
                        starting_affine=starting_affine)

        # Gradient computed by the implementation under test
        actual = mi_metric.gradient(theta)

        # Gradient estimated with forward finite differences
        n_params = transform.get_number_of_parameters()
        expected = np.empty(n_params, dtype=np.float64)
        base_val = mi_metric.distance(theta)
        for i in range(n_params):
            shifted = theta.copy()
            shifted[i] += h
            expected[i] = (mi_metric.distance(shifted) - base_val) / h

        # Cosine of the angle between both gradients
        cosine = expected.dot(actual) / (npl.norm(expected) *
                                         npl.norm(actual))
        assert cosine >= 0.99
340
341
def create_affine_transforms(
        dim, translations, rotations, scales, rot_axis=None):
    r""" Creates a list of affine transforms with all combinations of params

    This function is intended to be used for testing only. It generates
    affine transforms for all combinations of the input parameters in the
    following order: let T be a translation, R a rotation and S a scale. The
    generated affine will be:

    A = T.dot(S).dot(R).dot(T^{-1})

    Translation is handled this way because it is convenient to provide
    the translation parameters in terms of the center of rotation we wish
    to generate.

    Parameters
    ----------
    dim: int (either dim=2 or dim=3)
        dimension of the affine transforms
    translations: sequence of dim-tuples
        each dim-tuple (or array-like with at least `dim` entries)
        represents a translation parameter, i.e. the center of rotation
        and scaling
    rotations: sequence of floats
        each number represents a rotation angle in radians
    scales: sequence of floats
        each number represents a scale
    rot_axis: rotation axis (used for dim=3 only)
        passed unchanged to `geometry.rodrigues_axis_rotation`

    Returns
    -------
    transforms: list of (dim + 1)x(dim + 1) matrices
        one affine transform per (translation, rotation, scale)
        combination, with translations varying slowest and scales varying
        fastest. The list is empty if any of the input sequences is empty.
    """
    transforms = []
    for t in translations:
        # Accept tuples and lists as well as arrays
        center = np.asarray(t, dtype=np.float64)[:dim]
        # T^{-1} moves the center to the origin, T moves it back
        trans_inv = np.eye(dim + 1)
        trans_inv[:dim, dim] = -center
        trans = np.eye(dim + 1)
        trans[:dim, dim] = center
        for theta in rotations:  # rotation angle
            if dim == 2:
                ct = np.cos(theta)
                st = np.sin(theta)
                rot = np.array([[ct, -st, 0],
                                [st, ct, 0],
                                [0, 0, 1]])
            else:
                rot = np.eye(dim + 1)
                rot[:3, :3] = geometry.rodrigues_axis_rotation(rot_axis, theta)

            for s in scales:  # scale
                scale = np.eye(dim + 1) * s
                scale[dim, dim] = 1

                # Must be built inside the scale loop so that every scale
                # (not only the last one) contributes a transform
                affine = trans.dot(scale.dot(rot.dot(trans_inv)))
                transforms.append(affine)
    return transforms
398
399
def test_affine_map():
    """Test `AffineMap` construction, validation and warping in 2D and 3D.

    For each dimension the test:

    * checks `str`/`repr`/`format` against the corresponding dunder methods;
    * compares `transform` and `transform_inverse` (linear and nearest
      neighbor interpolation) against the `vector_fields` oracles;
    * checks `set_affine` and the `affine`/`affine_inv` attributes;
    * verifies that malformed matrices (non-square, bad augmentation row,
      NaN, all zeros, infinite, not 2-dimensional) and invalid calls
      (missing sampling information, unknown interpolation, image of the
      wrong dimension) raise the expected exceptions.
    """
    np.random.seed(2112927)
    dom_shape = np.array([64, 64, 64], dtype=np.int32)
    cod_shape = np.array([80, 80, 80], dtype=np.int32)
    # Radius of the circle/sphere (testing image)
    radius = 16
    # Rotation axis (used for 3D transforms only)
    rot_axis = np.array([.5, 2.0, 1.5])
    # Arbitrary transform parameters
    t = 0.15
    rotations = [-1 * np.pi / 10.0, 0.0, np.pi / 10.0]
    scales = [0.9, 1.0, 1.1]
    for dim in [2, 3]:
        # Setup current dimension
        if dim == 2:
            # Create image of a circle
            img = vf.create_circle(cod_shape[0], cod_shape[1], radius)
            oracle_linear = vf.transform_2d_affine
            oracle_nn = vf.transform_2d_affine_nn
        else:
            # Create image of a sphere
            img = vf.create_sphere(cod_shape[0], cod_shape[1], cod_shape[2],
                                   radius)
            oracle_linear = vf.transform_3d_affine
            oracle_nn = vf.transform_3d_affine_nn
        img = np.array(img)
        # Translation is the only parameter differing for 2D and 3D
        translations = [t * dom_shape[:dim]]
        # Generate affine transforms
        gt_affines = create_affine_transforms(dim, translations, rotations,
                                              scales, rot_axis)
        # Include the None case
        gt_affines.append(None)

        # testing str/format/repr
        for affine_mat in gt_affines:
            aff_map = AffineMap(affine_mat)
            assert_equal(str(aff_map), aff_map.__str__())
            assert_equal(repr(aff_map), aff_map.__repr__())
            for spec in ['f', 'r', 't', '']:
                assert_equal(format(aff_map, spec), aff_map.__format__(spec))

        for affine in gt_affines:

            # make both domain point to the same physical region
            # It's ok to use the same transform, we just want to test
            # that this information is actually being considered
            domain_grid2world = affine
            codomain_grid2world = affine
            grid2grid_transform = affine

            # Evaluate the transform with vector_fields module (already tested)
            expected_linear = oracle_linear(img, dom_shape[:dim],
                                            grid2grid_transform)
            expected_nn = oracle_nn(img, dom_shape[:dim], grid2grid_transform)

            # Evaluate the transform with the implementation under test
            affine_map = imaffine.AffineMap(affine,
                                            dom_shape[:dim], domain_grid2world,
                                            cod_shape[:dim],
                                            codomain_grid2world)
            actual_linear = affine_map.transform(img, interpolation='linear')
            actual_nn = affine_map.transform(img, interpolation='nearest')
            assert_array_almost_equal(actual_linear, expected_linear)
            assert_array_almost_equal(actual_nn, expected_nn)

            # Test set_affine with valid matrix
            affine_map.set_affine(affine)
            if affine is None:
                assert(affine_map.affine is None)
                assert(affine_map.affine_inv is None)
            else:
                # compatibility with previous versions
                assert_array_equal(affine, affine_map.affine)
                # new getter
                new_copy_affine = affine_map.affine
                # value must be the same
                assert_array_equal(affine, new_copy_affine)
                # but not its reference
                assert id(affine) != id(new_copy_affine)
                actual = affine_map.affine.dot(affine_map.affine_inv)
                assert_array_almost_equal(actual, np.eye(dim + 1))

            # Evaluate via the inverse transform

            # AffineMap will use the inverse of the input matrix when we call
            # `transform_inverse`. Since the inverse of the inverse of a matrix
            # is not exactly equal to the original matrix (numerical
            #  limitations) we need to invert the matrix twice to make sure
            # the oracle and the implementation under test apply the same
            # transform
            aff_inv = None if affine is None else npl.inv(affine)
            aff_inv_inv = None if aff_inv is None else npl.inv(aff_inv)
            expected_linear = oracle_linear(img, dom_shape[:dim],
                                            aff_inv_inv)
            expected_nn = oracle_nn(img, dom_shape[:dim], aff_inv_inv)

            # Domain and codomain swap roles here, since the map is built
            # from the inverse matrix and evaluated with `transform_inverse`
            affine_map = imaffine.AffineMap(aff_inv,
                                            cod_shape[:dim],
                                            codomain_grid2world,
                                            dom_shape[:dim], domain_grid2world)
            actual_linear = affine_map.transform_inverse(
                img, interpolation='linear')
            actual_nn = affine_map.transform_inverse(img,
                                                     interpolation='nearest')
            assert_array_almost_equal(actual_linear, expected_linear)
            assert_array_almost_equal(actual_nn, expected_nn)

        # Verify AffineMap can not be created with non-square matrix
        non_square_shapes = [np.zeros((dim, dim + 1), dtype=np.float64),
                             np.zeros((dim + 1, dim), dtype=np.float64)]
        for nsq in non_square_shapes:
            assert_raises(AffineInversionError, AffineMap, nsq)

        # Verify incorrect augmentations are caught
        for affine_mat in gt_affines:
            aff_map = AffineMap(affine_mat)
            if affine_mat is None:
                continue
            # `aff_map.affine` is expected to hand out a copy (see the
            # reference check above), so the edits below should not touch
            # the matrix stored in `aff_map`
            bad_aug = aff_map.affine
            # no zeros in the first n-1 columns on last row
            bad_aug[-1, :] = 1
            assert_raises(AffineInvalidValuesError, AffineMap, bad_aug)

            bad_aug = aff_map.affine
            bad_aug[-1, -1] = 0  # lower right not 1
            assert_raises(AffineInvalidValuesError, AffineMap, bad_aug)

        # Verify AffineMap cannot be created with a non-invertible matrix
        invalid_nan = np.zeros((dim + 1, dim + 1), dtype=np.float64)
        invalid_nan[1, 1] = np.nan
        invalid_zeros = np.zeros((dim + 1, dim + 1), dtype=np.float64)
        assert_raises(
            imaffine.AffineInvalidValuesError,
            imaffine.AffineMap,
            invalid_nan)
        assert_raises(
            AffineInvalidValuesError,
            imaffine.AffineMap,
            invalid_zeros)

        # Test exception is raised when the affine transform matrix is not
        # valid
        # Here the matrix is dim x dim instead of (dim + 1) x (dim + 1): the
        # map is created without error and the failure is expected only when
        # an image is warped
        invalid_shape = np.eye(dim)
        affmap_invalid_shape = imaffine.AffineMap(invalid_shape,
                                                  dom_shape[:dim], None,
                                                  cod_shape[:dim], None)
        assert_raises(ValueError, affmap_invalid_shape.transform, img)
        assert_raises(ValueError, affmap_invalid_shape.transform_inverse, img)

        # Verify exception is raised when sampling info is not provided
        # (the 3x3 identity is used for both values of `dim`, and the
        # variable name `affmap_invalid_shape` is simply reused)
        valid = np.eye(3)
        affmap_invalid_shape = imaffine.AffineMap(valid)
        assert_raises(ValueError, affmap_invalid_shape.transform, img)
        assert_raises(ValueError, affmap_invalid_shape.transform_inverse, img)

        # Verify exception is raised when requesting an invalid interpolation
        # (`affine_map` is the map left over from the last iteration of the
        # `for affine in gt_affines` loop above)
        assert_raises(ValueError, affine_map.transform, img, 'invalid')
        assert_raises(ValueError, affine_map.transform_inverse, img, 'invalid')

        # Verify exception is raised when attempting to warp an image of
        # invalid dimension
        # NOTE(review): this inner loop reuses the name `dim` of the
        # enclosing loop and rebinds `img`. It is harmless only because it
        # is the last statement of the outer loop body: the outer `for`
        # rebinds `dim` and `img` is recreated at the start of the next
        # iteration.
        for dim in [2, 3]:
            affine_map = imaffine.AffineMap(np.eye(dim),
                                            cod_shape[:dim], None,
                                            dom_shape[:dim], None)
            for sh in [(2,), (2, 2, 2, 2)]:
                img = np.zeros(sh)
                assert_raises(ValueError, affine_map.transform, img)
                assert_raises(ValueError, affine_map.transform_inverse, img)
            # All-zero, all-NaN and all-infinite matrices must be rejected
            # by `set_affine`
            aff_sing = np.zeros((dim + 1, dim + 1))
            aff_nan = np.zeros((dim + 1, dim + 1))
            aff_nan[...] = np.nan
            aff_inf = np.zeros((dim + 1, dim + 1))
            aff_inf[...] = np.inf

            assert_raises(
                AffineInvalidValuesError,
                affine_map.set_affine,
                aff_sing)
            assert_raises(AffineInvalidValuesError, affine_map.set_affine,
                          aff_nan)
            assert_raises(AffineInvalidValuesError, affine_map.set_affine,
                          aff_inf)

    # Verify AffineMap can not be created with non-2D matrices : len(shape) != 2
    # (arrays with 0 to 9 axes of length 2 are tried, skipping the number of
    # axes given by `_number_dim_affine_matrix`)
    for dim_not_2 in range(10):
        if dim_not_2 != _number_dim_affine_matrix:
            mat_large_dim = np.random.random([2]*dim_not_2)
            assert_raises(AffineInversionError, AffineMap, mat_large_dim)
590
591
def test_MIMetric_invalid_params():
    """Check the mutual information metric at invalid parameter vectors.

    Three invalid parameter vectors of a 3D affine transform are tried: all
    zeros (presumably yielding a singular matrix), all NaN and all infinite.
    For each of them `distance` must return an infinite value, `gradient`
    must return a vector of zeros, and `distance_and_gradient` must return
    both.
    """
    transform = regtransforms[('AFFINE', 3)]
    static = np.random.rand(20, 20, 20)
    moving = np.random.rand(20, 20, 20)
    n = transform.get_number_of_parameters()
    sampling_proportion = 0.3
    theta_sing = np.zeros(n)
    # Each vector is built independently: the NaN and infinite cases must
    # stay distinct so that all three kinds of invalid input are exercised
    theta_nan = np.full(n, np.nan)
    theta_inf = np.full(n, np.inf)

    mi_metric = imaffine.MutualInformationMetric(32, sampling_proportion)
    mi_metric.setup(transform, static, moving)
    expected_grad = np.zeros(n)
    for theta in [theta_sing, theta_nan, theta_inf]:
        # Test metric value at invalid params
        actual_val = mi_metric.distance(theta)
        assert np.isinf(actual_val)

        # Test gradient at invalid params
        actual_grad = mi_metric.gradient(theta)
        assert_equal(actual_grad, expected_grad)

        # Test both
        actual_val, actual_grad = mi_metric.distance_and_gradient(theta)
        assert np.isinf(actual_val)
        assert_equal(actual_grad, expected_grad)
620