1import numpy as np
2import numpy.testing as npt
3
4from dipy.core.sphere import unit_octahedron
5from dipy.reconst.shm import SphHarmFit, SphHarmModel
6from dipy.direction import (DeterministicMaximumDirectionGetter,
7                            ProbabilisticDirectionGetter)
8
9
def test_ProbabilisticDirectionGetter():
    """Test the constructors and errors of ProbabilisticDirectionGetter.

    Covers ``from_shcoeff`` and ``from_pmf`` built from all-zero inputs
    (for which ``get_direction`` is expected to return 1), the ValueError
    checks on malformed pmf arrays, and the ``basis_type`` keyword.
    The whole sequence is repeated for float32 and float64 data.
    """

    class SillyModel(SphHarmModel):
        """Minimal SH model whose fit always yields all-zero coefficients."""

        # sh_order 4 -> 15 coefficients (1 + 5 + 9 even-order terms).
        sh_order = 4

        def fit(self, data, mask=None):
            # Keep the input dtype so that the dtype loop below really
            # feeds float32 coefficients to ``from_shcoeff``; a plain
            # ``np.zeros`` would always be float64.
            coeff = np.zeros(data.shape[:-1] + (15,), dtype=data.dtype)
            return SphHarmFit(self, coeff, mask=mask)

    model = SillyModel(gtab=None)
    data = np.zeros((3, 3, 3, 7))

    # Test if the tracking works on different dtype of the same data.
    for dtype in [np.float32, np.float64]:
        fit = model.fit(data.astype(dtype))

        # Sample point and incoming direction (``direction`` avoids
        # shadowing the ``dir`` builtin).
        point = np.zeros(3)
        direction = unit_octahedron.vertices[0].copy()

        # Make a dg from a fit; all-zero coefficients are expected to
        # yield no valid direction, reported as state 1.
        dg = ProbabilisticDirectionGetter.from_shcoeff(fit.shm_coeff, 90,
                                                       unit_octahedron)
        state = dg.get_direction(point, direction)
        npt.assert_equal(state, 1)

        # Make a dg from an all-zero pmf (one value per sphere vertex).
        N = unit_octahedron.theta.shape[0]
        pmf = np.zeros((3, 3, 3, N))
        dg = ProbabilisticDirectionGetter.from_pmf(pmf, 90, unit_octahedron)
        state = dg.get_direction(point, direction)
        npt.assert_equal(state, 1)

        # pmf shape must match sphere
        bad_pmf = pmf[..., 1:]
        npt.assert_raises(ValueError, ProbabilisticDirectionGetter.from_pmf,
                          bad_pmf, 90, unit_octahedron)

        # pmf must have 4 dimensions
        bad_pmf = pmf[0, ...]
        npt.assert_raises(ValueError, ProbabilisticDirectionGetter.from_pmf,
                          bad_pmf, 90, unit_octahedron)

        # pmf cannot have negative values (pmf is recreated each iteration,
        # so this in-place edit does not leak into the next dtype).
        pmf[0, 0, 0, 0] = -1
        npt.assert_raises(ValueError, ProbabilisticDirectionGetter.from_pmf,
                          pmf, 90, unit_octahedron)

        # Check basis_type keyword: a valid basis builds a usable getter
        # (still no valid direction for zero coefficients) ...
        dg = ProbabilisticDirectionGetter.from_shcoeff(fit.shm_coeff, 90,
                                                       unit_octahedron,
                                                       basis_type="tournier07")
        state = dg.get_direction(point, direction)
        npt.assert_equal(state, 1)

        # ... while an unknown basis name is rejected.
        npt.assert_raises(ValueError,
                          ProbabilisticDirectionGetter.from_shcoeff,
                          fit.shm_coeff, 90, unit_octahedron,
                          basis_type="not a basis")
68
69
def test_DeterministicMaximumDirectionGetter():
    """Test DeterministicMaximumDirectionGetter on pmfs with no valid peak.

    Both cases are expected to make ``get_direction`` return 1:
    an all-zero pmf, and the BF #1566 regression case.
    """
    # ``direction`` rather than ``dir`` to avoid shadowing the builtin.
    direction = unit_octahedron.vertices[-1].copy()
    point = np.zeros(3)
    N = unit_octahedron.theta.shape[0]

    # No valid direction: the pmf is zero everywhere.
    pmf = np.zeros((3, 3, 3, N))
    dg = DeterministicMaximumDirectionGetter.from_pmf(pmf, 90,
                                                      unit_octahedron)
    state = dg.get_direction(point, direction)
    npt.assert_equal(state, 1)

    # Test BF #1566 - bad condition in DeterministicMaximumDirectionGetter.
    # The only nonzero pmf value is at vertex 0, while the incoming
    # direction is the last vertex; with max_angle 0 that vertex is
    # presumably outside the allowed cone, so no direction should be found.
    pmf = np.zeros((3, 3, 3, N))
    pmf[0, 0, 0, 0] = 1
    dg = DeterministicMaximumDirectionGetter.from_pmf(pmf, 0,
                                                      unit_octahedron)
    state = dg.get_direction(point, direction)
    npt.assert_equal(state, 1)
91