1import numpy as np
2from dipy.viz import regtools
3import numpy.testing as npt
4import pytest
5from dipy.align.metrics import SSDMetric
6from dipy.align.imwarp import SymmetricDiffeomorphicRegistration
7
8# Conditional import machinery for matplotlib
9from dipy.utils.optpkg import optional_package
10
11_, have_matplotlib, _ = optional_package('matplotlib')
12
13
@pytest.mark.skipif(not have_matplotlib, reason='Requires Matplotlib')
def test_plot_2d_diffeomorphic_map():
    """Smoke-test ``regtools.plot_2d_diffeomorphic_map``.

    Registers two small random 2D images with
    ``SymmetricDiffeomorphicRegistration`` (SSD metric), then checks the
    shapes of the two grids returned by the plotting helper. It does this
    once with the default grid shapes and once with explicitly requested
    shapes.
    """
    # Seeded local generator: keeps the images, the registration, and
    # hence the test outcome reproducible, independent of global RNG state.
    rng = np.random.RandomState(1234)
    mv_shape = (11, 12)
    moving = rng.rand(*mv_shape)
    st_shape = (13, 14)
    static = rng.rand(*st_shape)
    dim = static.ndim
    metric = SSDMetric(dim)
    level_iters = [200, 100, 50, 25]
    sdr = SymmetricDiffeomorphicRegistration(metric,
                                             level_iters,
                                             inv_iter=50)
    mapping = sdr.optimize(static, moving)

    # Smoke testing of plots with default grid shapes: the first (direct)
    # grid takes the static image shape, the second (inverse) grid takes
    # the moving image shape.
    ff = regtools.plot_2d_diffeomorphic_map(mapping, 10)
    npt.assert_equal(ff[0].shape, st_shape)
    npt.assert_equal(ff[1].shape, mv_shape)

    # Explicitly requested grid shapes override the defaults.
    ff = regtools.plot_2d_diffeomorphic_map(mapping,
                                            delta=10,
                                            direct_grid_shape=(7, 8),
                                            inverse_grid_shape=(9, 10))
    npt.assert_equal(ff[0].shape, (7, 8))
    npt.assert_equal(ff[1].shape, (9, 10))
40