"""Tests for the 2D registration plotting helpers in ``dipy.viz.regtools``."""
import numpy as np
from dipy.viz import regtools
import numpy.testing as npt
import pytest
from dipy.align.metrics import SSDMetric
from dipy.align.imwarp import SymmetricDiffeomorphicRegistration

# Conditional import machinery for matplotlib
from dipy.utils.optpkg import optional_package

_, have_matplotlib, _ = optional_package('matplotlib')


@pytest.mark.skipif(not have_matplotlib, reason='Requires Matplotlib')
def test_plot_2d_diffeomorphic_map():
    """Smoke-test ``regtools.plot_2d_diffeomorphic_map`` (lightly).

    Registers two small random 2D images of different shapes with a
    symmetric diffeomorphic registration. It then checks the shapes of the
    two grids returned by the plotting helper:

    * by default they match the static and moving image shapes;
    * when ``direct_grid_shape`` / ``inverse_grid_shape`` are given, they
      match those shapes instead.
    """
    # Local import: matplotlib is only known to be available once the
    # skipif guard above has let the test run.
    import matplotlib.pyplot as plt

    # Seeded generator so the random images, and hence the registration
    # result, are identical on every run (reproducible failures).
    rng = np.random.default_rng(1234)

    mv_shape = (11, 12)
    moving = rng.random(mv_shape)
    st_shape = (13, 14)
    static = rng.random(st_shape)

    dim = static.ndim
    metric = SSDMetric(dim)
    level_iters = [200, 100, 50, 25]
    sdr = SymmetricDiffeomorphicRegistration(metric,
                                             level_iters,
                                             inv_iter=50)
    mapping = sdr.optimize(static, moving)

    try:
        # Smoke testing of plots
        ff = regtools.plot_2d_diffeomorphic_map(mapping, 10)
        # Default shape is static shape, moving shape
        npt.assert_equal(ff[0].shape, st_shape)
        npt.assert_equal(ff[1].shape, mv_shape)

        # Can specify shape
        ff = regtools.plot_2d_diffeomorphic_map(mapping,
                                                delta=10,
                                                direct_grid_shape=(7, 8),
                                                inverse_grid_shape=(9, 10))
        npt.assert_equal(ff[0].shape, (7, 8))
        npt.assert_equal(ff[1].shape, (9, 10))
    finally:
        # The plotting helper opens figures; close them so they do not
        # accumulate across the test session, even when an assertion fails.
        plt.close('all')