import numpy as np
import numpy.testing as npt

from dipy.core.sphere import unit_octahedron
from dipy.reconst.shm import SphHarmFit, SphHarmModel
from dipy.direction import (DeterministicMaximumDirectionGetter,
                            ProbabilisticDirectionGetter)


def test_ProbabilisticDirectionGetter():
    """Test constructors and input validation of ProbabilisticDirectionGetter.

    Every direction getter built here comes from all-zero data (zero SH
    coefficients or a zero pmf), so ``get_direction`` is expected to return
    1 (no valid direction) for each of them. The checks are repeated for
    float32 and float64 input data.
    """

    class SillyModel(SphHarmModel):
        # Minimal model whose fit always yields all-zero SH coefficients.
        # 15 == (sh_order + 1) * (sh_order + 2) / 2 for sh_order = 4.

        sh_order = 4

        def fit(self, data, mask=None):
            coeff = np.zeros(data.shape[:-1] + (15,))
            return SphHarmFit(self, coeff, mask=None)

    model = SillyModel(gtab=None)
    data = np.zeros((3, 3, 3, 7))
    n_vertices = unit_octahedron.theta.shape[0]

    # Test if the tracking works on different dtype of the same data.
    for dtype in [np.float32, np.float64]:
        fit = model.fit(data.astype(dtype))

        # Sample point and direction. A fresh copy is taken each iteration
        # so the octahedron's own vertex array is never handed to
        # get_direction.
        point = np.zeros(3)
        direction = unit_octahedron.vertices[0].copy()

        # make a dg from a fit
        dg = ProbabilisticDirectionGetter.from_shcoeff(fit.shm_coeff, 90,
                                                        unit_octahedron)
        state = dg.get_direction(point, direction)
        npt.assert_equal(state, 1)

        # Make a dg from a pmf
        pmf = np.zeros((3, 3, 3, n_vertices))
        dg = ProbabilisticDirectionGetter.from_pmf(pmf, 90, unit_octahedron)
        state = dg.get_direction(point, direction)
        npt.assert_equal(state, 1)

        # pmf shape must match sphere
        bad_pmf = pmf[..., 1:]
        npt.assert_raises(ValueError, ProbabilisticDirectionGetter.from_pmf,
                          bad_pmf, 90, unit_octahedron)

        # pmf must have 4 dimensions
        bad_pmf = pmf[0, ...]
        npt.assert_raises(ValueError, ProbabilisticDirectionGetter.from_pmf,
                          bad_pmf, 90, unit_octahedron)

        # pmf cannot have negative values
        pmf[0, 0, 0, 0] = -1
        npt.assert_raises(ValueError, ProbabilisticDirectionGetter.from_pmf,
                          pmf, 90, unit_octahedron)

        # Check basis_type keyword: a getter built with an explicit basis
        # must also work. The coefficients are all zero, so the pmf is zero
        # in any basis and no valid direction exists.
        dg = ProbabilisticDirectionGetter.from_shcoeff(fit.shm_coeff, 90,
                                                        unit_octahedron,
                                                        basis_type="tournier07")
        state = dg.get_direction(point, direction)
        npt.assert_equal(state, 1)

        npt.assert_raises(ValueError,
                          ProbabilisticDirectionGetter.from_shcoeff,
                          fit.shm_coeff, 90, unit_octahedron,
                          basis_type="not a basis")


def test_DeterministicMaximumDirectionGetter():
    """Test DeterministicMaximumDirectionGetter when no direction is valid."""

    direction = unit_octahedron.vertices[-1].copy()
    point = np.zeros(3)
    n_vertices = unit_octahedron.theta.shape[0]

    # No valid direction
    pmf = np.zeros((3, 3, 3, n_vertices))
    dg = DeterministicMaximumDirectionGetter.from_pmf(pmf, 90,
                                                      unit_octahedron)
    state = dg.get_direction(point, direction)
    npt.assert_equal(state, 1)

    # Test BF #1566 - bad condition in DeterministicMaximumDirectionGetter.
    # The pmf is non-zero only at vertex 0 of voxel (0, 0, 0), while the
    # incoming direction is the last vertex. With max_angle=0 the getter
    # must still report that no valid direction exists.
    pmf = np.zeros((3, 3, 3, n_vertices))
    pmf[0, 0, 0, 0] = 1
    dg = DeterministicMaximumDirectionGetter.from_pmf(pmf, 0,
                                                      unit_octahedron)
    state = dg.get_direction(point, direction)
    npt.assert_equal(state, 1)