1from dipy.denoise.enhancement_kernel import EnhancementKernel
2from dipy.denoise.shift_twist_convolution import convolve, convolve_sf
3from dipy.reconst.shm import sh_to_sf, sf_to_sh
4from dipy.core.sphere import Sphere
5from dipy.data import get_sphere
6
7import numpy as np
8import numpy.testing as npt
9
def test_enhancement_kernel():
    """ Test if the kernel values are correct by comparison against the values
    originally calculated by an implementation in Mathematica, and at the same
    time check the symmetry of the kernel.

    Each sample i evaluates the kernel at position ``positionlist[i]`` and
    orientation ``orientationlist[i]`` relative to the origin ``y`` with
    reference orientation ``v`` (the z-axis), and compares the result with
    the precomputed ``kernelvalues[i]``.
    """

    D33 = 1.0
    D44 = 0.04
    t = 1
    # orientations=0 gives an empty lookup table; only evaluate_kernel is
    # exercised here.
    k = EnhancementKernel(D33, D44, t, orientations=0, force_recompute=True)

    y = np.array([0., 0., 0.])
    v = np.array([0., 0., 1.])
    orientationlist = [
        [0., 0., 1.], [-0.0527864, 0.688191, 0.723607],
        [-0.67082, -0.16246, 0.723607], [-0.0527864, -0.688191, 0.723607],
        [0.638197, -0.262866, 0.723607], [0.831052, 0.238856, 0.502295],
        [0.262866, -0.809017, -0.525731], [0.812731, 0.295242, -0.502295],
        [-0.029644, 0.864188, -0.502295], [-0.831052, 0.238856, -0.502295],
        [-0.638197, -0.262866, -0.723607], [-0.436009, 0.864188, -0.251148],
        [-0.687157, -0.681718, 0.251148], [0.67082, -0.688191, 0.276393],
        [0.67082, 0.688191, 0.276393], [0.947214, 0.16246, -0.276393],
        [-0.861803, -0.425325, -0.276393],
    ]
    positionlist = [
        [-0.108096, 0.0412229, 0.339119], [0.220647, -0.422053, 0.427524],
        [-0.337432, -0.0644619, -0.340777], [0.172579, -0.217602, -0.292446],
        [-0.271575, -0.125249, -0.350906], [-0.483807, 0.326651, 0.191993],
        [-0.480936, -0.0718426, 0.33202], [0.497193, -0.00585659, -0.251344],
        [0.237737, 0.013634, -0.471988], [0.367569, -0.163581, 0.0723955],
        [0.47859, -0.143252, 0.318579], [-0.21474, -0.264929, -0.46786],
        [-0.0684234, 0.0342464, 0.0942475], [0.344272, 0.423119, -0.303866],
        [0.0430714, 0.216233, -0.308475], [0.386085, 0.127333, 0.0503609],
        [0.334723, 0.071415, 0.403906],
    ]
    kernelvalues = [
        0.10701063104295713, 0.0030052117308328923, 0.003125410084676201,
        0.0031765819772012613, 0.003127254657020615, 0.0001295130396491743,
        6.882352014430076e-14, 1.3821277371353332e-13, 1.3951939946082493e-13,
        1.381612071786285e-13, 5.0861109163441125e-17, 1.0722120295517027e-10,
        2.425145934791457e-6, 3.557919265806602e-6, 3.6669510385105265e-6,
        5.97473789679846e-11, 6.155412262223178e-11,
    ]

    # The three tables describe the same samples; guard against an edit to
    # one of them, since zip() would otherwise silently drop trailing entries.
    npt.assert_equal(len(positionlist), len(orientationlist))
    npt.assert_equal(len(kernelvalues), len(orientationlist))

    samples = zip(orientationlist, positionlist, kernelvalues)
    for idx, (orientation, position, expected) in enumerate(samples):
        r = np.array(orientation)
        x = np.array(position)
        npt.assert_almost_equal(
            k.evaluate_kernel(x, y, r, v), expected,
            err_msg="kernel value mismatch at sample {0}".format(idx))
51
52
def test_spike():
    """ Test if a convolution with a delta spike is equal to the kernel
    saved in the lookup table.

    A unit spike is placed at the centre voxel of a 7x7x7 grid in
    orientation 0 and convolved with the kernel. For every output
    orientation i, the convolved volume is compared with the lookup-table
    slice ``lut[i, 0]``.

    The comparison is done per orientation rather than on one grand total,
    so an error in one orientation cannot be cancelled by an opposite error
    in another. Totals (rather than element-wise values) are compared, so the
    check does not depend on the spatial index ordering used by the
    convolution.
    """

    # create kernel
    D33 = 1.0
    D44 = 0.04
    t = 1
    num_orientations = 5
    k = EnhancementKernel(D33, D44, t, orientations=num_orientations,
                          force_recompute=True)

    # create a delta spike at the grid centre, orientation 0
    numorientations = k.get_orientations().shape[0]
    spike = np.zeros((7, 7, 7, numorientations), dtype=np.float64)
    spike[3, 3, 3, 0] = 1

    # convolve kernel with delta spike
    csd_enh = np.array(convolve_sf(spike, k, test_mode=True, normalize=False))

    # convert the lookup table once instead of on every loop iteration
    lut = np.array(k.get_lookup_table())

    # check, orientation by orientation, that the kernel matches the
    # convolved delta spike
    for i in range(numorientations):
        npt.assert_almost_equal(
            np.sum(csd_enh[:, :, :, i]), np.sum(lut[i, 0, :, :, :]),
            err_msg="convolved spike differs from kernel in orientation "
                    "{0}".format(i))
78
def test_normalization():
    """ Test the normalization routine applied after a convolution.

    A constant (all-ones) dataset is converted to spherical harmonics (SH),
    convolved with the kernel with ``normalize=True``, and converted back to
    a spherical function on the kernel's sphere. The maximum of the result
    must equal the maximum of the input.
    """
    # create kernel
    D33 = 1.0
    D44 = 0.04
    t = 1
    num_orientations = 5
    k = EnhancementKernel(D33, D44, t, orientations=num_orientations,
                          force_recompute=True)

    # create a constant dataset: every voxel/orientation has value 1
    numorientations = k.get_orientations().shape[0]
    constant = np.ones((7, 7, 7, numorientations), dtype=np.float64)

    # convert dataset to SH
    constant_sh = sf_to_sh(constant, k.get_sphere(), sh_order=8)

    # convolve kernel with the constant dataset and apply normalization
    csd_enh = convolve(constant_sh, k, sh_order=8, test_mode=True,
                       normalize=True)

    # convert the result back to a spherical function on the kernel's sphere
    csd_enh_dsf = sh_to_sf(csd_enh, k.get_sphere(), sh_order=8,
                           basis_type=None)

    # test if the normalization is performed correctly
    npt.assert_almost_equal(np.amax(csd_enh_dsf), np.amax(constant))
103
def test_kernel_input():
    """ Test the kernel lookup-table shape for an ``orientations`` argument
    given as a Sphere, as a positive int, and as 0.

    The first two axes of the lookup table are sized by the number of
    orientations; the remaining three are a 7x7x7 spatial grid.
    """

    D33 = 1.0
    D44 = 0.04
    t = 1

    # a Sphere with a single vertex -> 1 x 1 orientation table
    sph = Sphere(1, 0, 0)
    k = EnhancementKernel(D33, D44, t, orientations=sph, force_recompute=True)
    npt.assert_equal(k.get_lookup_table().shape, (1, 1, 7, 7, 7))

    # an int number of orientations -> n x n orientation table
    num_orientations = 2
    k = EnhancementKernel(D33, D44, t, orientations=num_orientations,
                          force_recompute=True)
    npt.assert_equal(k.get_lookup_table().shape, (2, 2, 7, 7, 7))

    # zero orientations -> empty orientation axes
    k = EnhancementKernel(D33, D44, t, orientations=0, force_recompute=True)
    npt.assert_equal(k.get_lookup_table().shape, (0, 0, 7, 7, 7))
120
if __name__ == '__main__':
    # When executed as a script, run this module's tests through
    # numpy.testing's module runner.
    # NOTE(review): run_module_suite is nose-based and deprecated in newer
    # NumPy releases; confirm it is still available or run via pytest.
    run_suite = npt.run_module_suite
    run_suite()
123