1import numpy as np
2from dipy.data import default_sphere, get_3shell_gtab, get_isbi2013_2shell_gtab
3from dipy.reconst.shore import ShoreModel
4from dipy.reconst.shm import sh_to_sf
5from dipy.direction.peaks import peak_directions
6from dipy.reconst.odf import gfa
7from numpy.testing import (assert_equal,
8                           assert_almost_equal,
9                           run_module_suite)
10from dipy.sims.voxel import sticks_and_ball
11from dipy.core.subdivide_octahedron import create_unit_sphere
12from dipy.core.sphere_stats import angular_similarity
13from dipy.reconst.tests.test_dsi import sticks_and_ball_dummies
14
15
def _assert_two_golden_peaks(odf, sphere, golden_directions):
    """Assert that ``odf`` has exactly two peaks close to ``golden_directions``.

    Peaks are extracted with ``peak_directions`` using a relative peak
    threshold of 0.35 and a minimum separation angle of 25. The test expects
    the angular similarity between the recovered and true directions to be
    about 2 (one per fibre), compared to one decimal place.
    """
    directions, _, _ = peak_directions(odf, sphere, .35, 25)
    assert_equal(len(directions), 2)
    assert_almost_equal(
        angular_similarity(directions, golden_directions), 2, 1)


def test_shore_odf():
    """Check SHORE ODF reconstruction on synthetic stick-and-ball signals.

    A noise-free two-stick signal is fitted with a ``ShoreModel``. The test
    checks three things:

    * the ODF sampled directly on a sphere matches the ODF reconstructed
      from its spherical-harmonic coefficients (``odf_sh``);
    * exactly two peaks are recovered near the true fibre directions, on
      two different spheres;
    * for each dummy signal from ``sticks_and_ball_dummies``, the peak count
      matches the ground truth. When more than three peaks are found, the
      ODF must instead have a low generalized fractional anisotropy.
    """
    gtab = get_isbi2013_2shell_gtab()

    # dipy's default sphere (built from the repulsion 724 sphere)
    sphere = default_sphere

    # Sphere built by subdividing an octahedron 5 times
    # (dipy.core.subdivide_octahedron), not an icosahedron.
    sphere2 = create_unit_sphere(5)

    # Two equally weighted sticks with no noise. The angle pairs are
    # presumably (theta, phi) in degrees -- TODO confirm against sticks_and_ball.
    data, golden_directions = sticks_and_ball(gtab, d=0.0015, S0=100,
                                              angles=[(0, 0), (90, 0)],
                                              fractions=[50, 50], snr=None)
    asm = ShoreModel(gtab, radial_order=6,
                     zeta=700, lambdaN=1e-8, lambdaL=1e-8)

    # Sphere 1: the ODF evaluated on the sphere must agree with the ODF
    # rebuilt from its SH representation (order 6, legacy basis).
    asmfit = asm.fit(data)
    odf = asmfit.odf(sphere)
    odf_sh = asmfit.odf_sh()
    odf_from_sh = sh_to_sf(odf_sh, sphere, 6, basis_type=None,
                           legacy=True)
    assert_almost_equal(odf, odf_from_sh, 10)
    _assert_two_golden_peaks(odf, sphere, golden_directions)

    # Sphere 2: same fit, sampled on the 5-times subdivided octahedron.
    odf = asmfit.odf(sphere2)
    _assert_two_golden_peaks(odf, sphere2, golden_directions)

    # Each dummy is a (signal, golden_directions) pair.
    sb_dummies = sticks_and_ball_dummies(gtab)
    for data, golden_directions in sb_dummies.values():
        asmfit = asm.fit(data)
        odf = asmfit.odf(sphere2)
        directions, _, _ = peak_directions(odf, sphere2, .35, 25)
        if len(directions) <= 3:
            assert_equal(len(directions), len(golden_directions))
        else:
            # Many spurious peaks: the profile should be close to isotropic,
            # i.e. have a low generalized fractional anisotropy.
            assert_equal(gfa(odf) < 0.1, True)
59
60
def test_multivox_shore():
    """Fit SHORE voxel-wise on a small random 4D volume.

    The test checks two things:

    * the first three axes of ``shore_coeff`` match the spatial axes of the
      input data, so there is one coefficient vector per voxel;
    * every coefficient is real-valued.
    """
    gtab = get_3shell_gtab()

    # A seeded generator makes any failure reproducible from run to run.
    rng = np.random.default_rng(1234)
    data = rng.random((20, 30, 1, gtab.gradients.shape[0]))
    radial_order = 4
    zeta = 700
    asm = ShoreModel(gtab, radial_order=radial_order,
                     zeta=zeta, lambdaN=1e-8, lambdaL=1e-8)
    asmfit = asm.fit(data)
    c_shore = asmfit.shore_coeff
    assert_equal(c_shore.shape[0:3], data.shape[0:3])
    # np.alltrue was deprecated in NumPy 1.25 and removed in NumPy 2.0;
    # np.all is the supported equivalent.
    assert_equal(np.all(np.isreal(c_shore)), True)
73
74
if __name__ == '__main__':
    # Executing this file directly runs every test defined in the module.
    run_module_suite()
77