1
2import numpy as np
3from dipy.testing import assert_true, assert_false
4from numpy.testing import (assert_array_equal, assert_array_almost_equal,
5                           assert_equal, assert_raises)
6import dipy.tracking.vox2track as tvo
7
8
def tracks_to_expected(tracks, vol_dims):
    """Reference (pure numpy) model of voxel track counting for the tests.

    Every track point is rounded to the nearest integer index and clipped
    into ``[0, vol_dims - 1]`` along each axis. A voxel that the same track
    visits more than once is only recorded for that track the first time.

    Parameters
    ----------
    tracks : sequence of array-like
        One 2-D array per track; each row is a point with one coordinate per
        volume axis, already expressed in voxel units.
    vol_dims : sequence of int
        Shape of the volume in voxels.

    Returns
    -------
    counts : ndarray of int32, shape ``vol_dims``
        Number of distinct tracks passing through each voxel.
    elements : dict
        Maps each visited voxel (tuple of indices) to the list of indices of
        the tracks passing through it, in track order.
    """
    shape = np.array(vol_dims, dtype=np.int32)
    counts = np.zeros(shape, dtype=np.int32)
    elements = {}
    upper = shape - 1
    for track_index, track in enumerate(tracks):
        # NOTE(review): np.round rounds exact halves to even (0.5 -> 0,
        # 1.5 -> 2); confirm this matches track_counts before adding
        # half-voxel points to the test data.
        voxels = np.clip(np.round(track).astype(np.int32), 0, upper)
        # dict.fromkeys keeps first-seen order while dropping repeat visits
        for voxel in dict.fromkeys(map(tuple, voxels)):
            elements.setdefault(voxel, []).append(track_index)
            counts[voxel] += 1
    return counts, elements
34
35
def test_track_volumes():
    """Check ``track_counts`` against the reference ``tracks_to_expected``.

    Covers the minimal case, repeated / non-integer / out-of-volume points,
    and the same tracks scaled by non-unit voxel sizes.
    """
    # simplest case
    vol_dims = (1, 2, 3)
    tracks = ([[0, 0, 0],
               [0, 1, 1]],)
    tracks = [np.array(t) for t in tracks]
    ex_counts, ex_els = tracks_to_expected(tracks, vol_dims)
    tcs, tes = tvo.track_counts(tracks, vol_dims, [1, 1, 1])
    assert_array_equal(tcs, ex_counts)
    # The expected elements are a dict keyed by voxel tuple; assert_equal
    # compares dicts key by key, whereas assert_array_equal only works by
    # wrapping both dicts in 0-d object arrays.
    assert_equal(tes, ex_els)
    # check only counts returned for return_elements=False
    tcs = tvo.track_counts(tracks, vol_dims, [1, 1, 1], return_elements=False)
    assert_array_equal(tcs, ex_counts)

    # non-unique points, non-integer points, points outside
    vol_dims = (5, 10, 15)
    tracks = ([[-1, 0, 1],
               [0, 0.1, 0],
               [1, 1, 1],
               [1, 1, 1],
               [2, 2, 2]],
              [[0.7, 0, 0],
               [1, 1, 1],
               [1, 2, 2],
               [1, 11, 0]])
    tracks = [np.array(t) for t in tracks]
    ex_counts, ex_els = tracks_to_expected(tracks, vol_dims)
    tcs, tes = tvo.track_counts(tracks, vol_dims, [1, 1, 1])
    assert_array_equal(tcs, ex_counts)
    assert_equal(tes, ex_els)

    # points with non-unit voxel sizes: scaling by the voxel size should map
    # back onto the same voxels as the unit-size case above
    vox_sizes = [1.4, 2.1, 3.7]
    float_tracks = [t * vox_sizes for t in tracks]
    tcs, tes = tvo.track_counts(float_tracks, vol_dims, vox_sizes)
    assert_array_equal(tcs, ex_counts)
    assert_equal(tes, ex_els)
    tcs = tvo.track_counts(float_tracks, vol_dims, vox_sizes,
                           return_elements=False)
    assert_array_equal(tcs, ex_counts)
74