1"""Various tools related to creating and working with streamlines
2
3This module provides tools for targeting streamlines using ROIs, for making
4connectivity matrices from whole brain fiber tracking and some other tools that
5allow streamlines to interact with image data.
6
7Important Notes
8-----------------
9Dipy uses affine matrices to represent the relationship between streamline
10points, which are defined as points in a continuous 3d space, and image voxels,
11which are typically arranged in a discrete 3d grid. Dipy uses a convention
12similar to nifti files to interpret these affine matrices. This convention is
13that the point at the center of voxel ``[i, j, k]`` is represented by the point
14``[x, y, z]`` where ``[x, y, z, 1] = affine * [i, j, k, 1]``.  Also when the
15phrase "voxel coordinates" is used, it is understood to be the same as ``affine
16= eye(4)``.
17
18As an example, lets take a 2d image where the affine is::
19
20    [[1., 0., 0.],
21     [0., 2., 0.],
22     [0., 0., 1.]]
23
24The pixels of an image with this affine would look something like::
25
26    A------------
27    |   |   |   |
28    | C |   |   |
29    |   |   |   |
30    ----B--------
31    |   |   |   |
32    |   |   |   |
33    |   |   |   |
34    -------------
35    |   |   |   |
36    |   |   |   |
37    |   |   |   |
38    ------------D
39
40And the letters A-D represent the following points in
41"real world coordinates"::
42
43    A = [-.5, -1.]
44    B = [ .5,  1.]
45    C = [ 0.,  0.]
46    D = [ 2.5,  5.]
47
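
For reference, the same mapping can be reproduced with
``nibabel.affines.apply_affine`` (an illustrative check of the points above,
using the 3x3 affine of this 2d example)::

    import numpy as np
    from nibabel.affines import apply_affine

    affine = np.array([[1., 0., 0.],
                       [0., 2., 0.],
                       [0., 0., 1.]])
    apply_affine(affine, np.array([0., 0.]))    # C -> [ 0.,  0.]
    apply_affine(affine, np.array([2.5, 2.5]))  # D -> [ 2.5,  5.]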
48"""
49
50from functools import wraps
51from warnings import warn
52
53from nibabel.affines import apply_affine
54from scipy.spatial.distance import cdist
55from numpy import ravel_multi_index
56
57from dipy.core.geometry import dist_to_corner
58
59from collections import defaultdict, OrderedDict
60from itertools import combinations, groupby
61
62import numpy as np
63from numpy import (asarray, ceil, empty, sqrt)
64from dipy.tracking import metrics
65from dipy.tracking.vox2track import _streamlines_in_mask
66
67# Import helper functions shared with vox2track
68from dipy.tracking._utils import (_mapping_to_voxel, _to_voxel_coordinates)
69
70
71def density_map(streamlines, affine, vol_dims):
72    """Counts the number of unique streamlines that pass through each voxel.
73
74    Parameters
75    ----------
76    streamlines : iterable
77        A sequence of streamlines.
78    affine : array_like (4, 4)
79        The mapping from voxel coordinates to streamline points.
80        The voxel_to_rasmm matrix, typically from a NIFTI file.
81    vol_dims : 3 ints
82        The shape of the volume to be returned containing the streamlines
83        counts
84
85    Returns
86    -------
87    image_volume : ndarray, shape=vol_dims
88        The number of streamline points in each voxel of volume.
89
90    Raises
91    ------
92    IndexError
93        When the points of the streamlines lie outside of the return volume.
94
95    Notes
96    -----
97    A streamline can pass through a voxel even if one of the points of the
98    streamline does not lie in the voxel. For example a step from [0,0,0] to
99    [0,0,2] passes through [0,0,1]. Consider subsegmenting the streamlines when
100    the edges of the voxels are smaller than the steps of the streamlines.
101
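
    Examples
    --------
    A minimal, illustrative example. With an identity affine, world and voxel
    coordinates coincide, and the single streamline below visits two voxels:

    >>> streamlines = [np.array([[0., 0., 0.], [1., 0., 0.]])]
    >>> dm = density_map(streamlines, np.eye(4), (2, 2, 2))
    >>> int(dm[0, 0, 0]), int(dm[1, 0, 0]), int(dm.sum())
    (1, 1, 2)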
102    """
103    lin_T, offset = _mapping_to_voxel(affine)
104    counts = np.zeros(vol_dims, 'int')
105    for sl in streamlines:
106        inds = _to_voxel_coordinates(sl, lin_T, offset)
107        i, j, k = inds.T
108        # this takes advantage of the fact that numpy's += operator only
109        # acts once even if there are repeats in inds
110        counts[i, j, k] += 1
111    return counts
112
113
114def connectivity_matrix(streamlines, affine, label_volume, inclusive=False,
115                        symmetric=True, return_mapping=False,
116                        mapping_as_streamlines=False):
117    """Counts the streamlines that start and end at each label pair.
118
119    Parameters
120    ----------
121    streamlines : sequence
122        A sequence of streamlines.
123    affine : array_like (4, 4)
124        The mapping from voxel coordinates to streamline coordinates.
125        The voxel_to_rasmm matrix, typically from a NIFTI file.
126    label_volume : ndarray
127        An image volume with an integer data type, where the intensities in the
128        volume map to anatomical structures.
129    inclusive: bool
130        Whether to analyze the entire streamline, as opposed to just the
131        endpoints. Allowing this will increase calculation time and mapping
132        size, especially if mapping_as_streamlines is True. False by default.
133    symmetric : bool, True by default
134        Symmetric means we don't distinguish between start and end points. If
135        symmetric is True, ``matrix[i, j] == matrix[j, i]``.
136    return_mapping : bool, False by default
137        If True, a mapping is returned which maps matrix indices to
138        streamlines.
139    mapping_as_streamlines : bool, False by default
140        If True voxel indices map to lists of streamline objects. Otherwise
141        voxel indices map to lists of integers.
142
143    Returns
144    -------
145    matrix : ndarray
146        The number of connection between each pair of regions in
147        `label_volume`.
148    mapping : defaultdict(list)
149        ``mapping[i, j]`` returns all the streamlines that connect region `i`
150        to region `j`. If `symmetric` is True mapping will only have one key
151        for each start end pair such that if ``i < j`` mapping will have key
152        ``(i, j)`` but not key ``(j, i)``.
153
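
    Examples
    --------
    An illustrative sketch with an identity affine and a small label volume
    containing two regions (labels 1 and 2), connected by one streamline:

    >>> labels = np.zeros((3, 3, 3), dtype=np.uint8)
    >>> labels[0, 0, 0] = 1
    >>> labels[2, 0, 0] = 2
    >>> streamlines = [np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]])]
    >>> matrix = connectivity_matrix(streamlines, np.eye(4), labels)
    >>> int(matrix[1, 2])
    1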
154    """
155    # Error checking on label_volume
156    kind = label_volume.dtype.kind
157    labels_positive = ((kind == 'u') or
158                       ((kind == 'i') and (label_volume.min() >= 0)))
159    valid_label_volume = (labels_positive and label_volume.ndim == 3)
160    if not valid_label_volume:
161        raise ValueError("label_volume must be a 3d integer array with"
162                         "non-negative label values")
163
164    # If streamlines is an iterator
165    if return_mapping and mapping_as_streamlines:
166        streamlines = list(streamlines)
167
168    if inclusive:
169        # Create ndarray to store streamline connections
170        edges = np.ndarray(shape=(3, 0), dtype=int)
171        lin_T, offset = _mapping_to_voxel(affine)
172        for sl, _ in enumerate(streamlines):
173            # Convert streamline to voxel coordinates
174            entire = _to_voxel_coordinates(streamlines[sl], lin_T, offset)
175            i, j, k = entire.T
176
177            if symmetric:
178                # Create list of all labels streamline passes through
179                entirelabels = list(OrderedDict.fromkeys(label_volume[i, j, k]))
180                # Append all connection combinations with streamline number
181                for comb in combinations(entirelabels, 2):
182                    edges = np.append(edges, [[comb[0]], [comb[1]], [sl]],
183                                      axis=1)
184            else:
185                # Create list of all labels streamline passes through, keeping
186                # order and whether a label was entered multiple times
187                entirelabels = list(groupby(label_volume[i, j, k]))
188                # Append connection combinations along with streamline number,
189                # removing duplicates and connections from a label to itself
190                combs = set(combinations([z[0] for z in entirelabels], 2))
191                for comb in combs:
192                    if comb[0] == comb[1]:
193                        pass
194                    else:
195                        edges = np.append(edges, [[comb[0]], [comb[1]], [sl]],
196                                          axis=1)
197        if symmetric:
198            edges[0:2].sort(0)
199        mx = label_volume.max() + 1
200        matrix = ndbincount(edges[0:2], shape=(mx, mx))
201
202        if symmetric:
203            matrix = np.maximum(matrix, matrix.T)
204        if return_mapping:
205            mapping = defaultdict(list)
206            for i, (a, b, c) in enumerate(edges.T):
207                mapping[a, b].append(c)
208            # Replace each list of indices with the streamlines they index
209            if mapping_as_streamlines:
210                for key in mapping:
211                    mapping[key] = [streamlines[i] for i in mapping[key]]
212
213            return matrix, mapping
214
215        return matrix
216    else:
217        # take the first and last point of each streamline
218        endpoints = [sl[0::len(sl)-1] for sl in streamlines]
219
220        # Map the streamlines coordinates to voxel coordinates
221        lin_T, offset = _mapping_to_voxel(affine)
222        endpoints = _to_voxel_coordinates(endpoints, lin_T, offset)
223
224        # get labels for label_volume
225        i, j, k = endpoints.T
226        endlabels = label_volume[i, j, k]
227        if symmetric:
228            endlabels.sort(0)
229        mx = label_volume.max() + 1
230        matrix = ndbincount(endlabels, shape=(mx, mx))
231        if symmetric:
232            matrix = np.maximum(matrix, matrix.T)
233
234        if return_mapping:
235            mapping = defaultdict(list)
236            for i, (a, b) in enumerate(endlabels.T):
237                mapping[a, b].append(i)
238
239            # Replace each list of indices with the streamlines they index
240            if mapping_as_streamlines:
241                for key in mapping:
242                    mapping[key] = [streamlines[i] for i in mapping[key]]
243
244            # Return the mapping matrix and the mapping
245            return matrix, mapping
246
247        return matrix
248
249
250def ndbincount(x, weights=None, shape=None):
251    """Like bincount, but for nd-indices.
252
253    Parameters
254    ----------
255    x : array_like (N, M)
256        M indices to a an Nd-array
257    weights : array_like (M,), optional
258        Weights associated with indices
259    shape : optional
260        the shape of the output
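
    Examples
    --------
    For illustration, counting the 2d indices ``(0, 1)`` (twice) and
    ``(1, 0)`` (once) on a 2 x 2 grid:

    >>> x = np.array([[0, 0, 1],
    ...               [1, 1, 0]])
    >>> counts = ndbincount(x, shape=(2, 2))
    >>> int(counts[0, 1]), int(counts[1, 0])
    (2, 1)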
261    """
262    x = np.asarray(x)
263    if shape is None:
264        shape = x.max(1) + 1
265
266    x = ravel_multi_index(x, shape)
267    out = np.bincount(x, weights, minlength=np.prod(shape))
268    out.shape = shape
269
270    return out
271
272
273def reduce_labels(label_volume):
274    """Reduces an array of labels to the integers from 0 to n with smallest
275    possible n.
276
277    Examples
278    --------
279    >>> labels = np.array([[1, 3, 9],
280    ...                    [1, 3, 8],
281    ...                    [1, 3, 7]])
282    >>> new_labels, lookup = reduce_labels(labels)
283    >>> lookup
284    array([1, 3, 7, 8, 9])
285    >>> new_labels #doctest: +ELLIPSIS
286    array([[0, 1, 4],
287           [0, 1, 3],
288           [0, 1, 2]]...)
289    >>> (lookup[new_labels] == labels).all()
290    True
291    """
292    lookup_table = np.unique(label_volume)
293    label_volume = lookup_table.searchsorted(label_volume)
294    return label_volume, lookup_table
295
296
297def subsegment(streamlines, max_segment_length):
298    """Splits the segments of the streamlines into small segments.
299
300    Replaces each segment of each of the streamlines with the smallest possible
301    number of equally sized smaller segments such that no segment is longer
302    than max_segment_length. Among other things, this can useful for getting
303    streamline counts on a grid that is smaller than the length of the
304    streamline segments.
305
306    Parameters
307    ----------
308    streamlines : sequence of ndarrays
309        The streamlines to be subsegmented.
310    max_segment_length : float
311        The longest allowable segment length.
312
313    Returns
314    -------
315    output_streamlines : generator
316        A set of streamlines.
317
318    Notes
319    -----
320    Segments of 0 length are removed. If unchanged
321
322    Examples
323    --------
324    >>> streamlines = [np.array([[0,0,0],[2,0,0],[5,0,0]])]
325    >>> list(subsegment(streamlines, 3.))
326    [array([[ 0.,  0.,  0.],
327           [ 2.,  0.,  0.],
328           [ 5.,  0.,  0.]])]
329    >>> list(subsegment(streamlines, 1))
330    [array([[ 0.,  0.,  0.],
331           [ 1.,  0.,  0.],
332           [ 2.,  0.,  0.],
333           [ 3.,  0.,  0.],
334           [ 4.,  0.,  0.],
335           [ 5.,  0.,  0.]])]
336    >>> list(subsegment(streamlines, 1.6))
337    [array([[ 0. ,  0. ,  0. ],
338           [ 1. ,  0. ,  0. ],
339           [ 2. ,  0. ,  0. ],
340           [ 3.5,  0. ,  0. ],
341           [ 5. ,  0. ,  0. ]])]
342    """
343    for sl in streamlines:
344        diff = (sl[1:] - sl[:-1])
345        length = sqrt((diff*diff).sum(-1))
346        num_segments = ceil(length/max_segment_length).astype('int')
347
348        output_sl = empty((num_segments.sum()+1, 3), 'float')
349        output_sl[0] = sl[0]
350
351        count = 1
352        for ii in range(len(num_segments)):
353            ns = num_segments[ii]
354            if ns == 1:
355                output_sl[count] = sl[ii+1]
356                count += 1
357            elif ns > 1:
358                small_d = diff[ii]/ns
359                point = sl[ii]
360                for _ in range(ns):
361                    point = point + small_d
362                    output_sl[count] = point
363                    count += 1
364            elif ns == 0:
365                pass
366                # repeated point
367            else:
368                # this should never happen because ns should be a positive
369                # int
370                assert(ns >= 0)
371        yield output_sl
372
373
374def seeds_from_mask(mask, affine, density=[1, 1, 1]):
375    """Create seeds for fiber tracking from a binary mask.
376
377    Seeds points are placed evenly distributed in all voxels of ``mask`` which
378    are ``True``.
379
380    Parameters
381    ----------
382    mask : binary 3d array_like
383        A binary array specifying where to place the seeds for fiber tracking.
384    affine : array, (4, 4)
385        The mapping between voxel indices and the point space for seeds.
386        The voxel_to_rasmm matrix, typically from a NIFTI file.
387        A seed point at the center the voxel ``[i, j, k]``
388        will be represented as ``[x, y, z]`` where
389        ``[x, y, z, 1] == np.dot(affine, [i, j, k , 1])``.
390    density : int or array_like (3,)
391        Specifies the number of seeds to place along each dimension. A
392        ``density`` of `2` is the same as ``[2, 2, 2]`` and will result in a
393        total of 8 seeds per voxel.
394
395    See Also
396    --------
397    random_seeds_from_mask
398
399    Raises
400    ------
401    ValueError
402        When ``mask`` is not a three-dimensional array
403
404    Examples
405    --------
406    >>> mask = np.zeros((3,3,3), 'bool')
407    >>> mask[0,0,0] = 1
408    >>> seeds_from_mask(mask, np.eye(4), [1,1,1])
409    array([[ 0.,  0.,  0.]])
410    """
411    mask = np.array(mask, dtype=bool, copy=False, ndmin=3)
412    if mask.ndim != 3:
413        raise ValueError('mask cannot be more than 3d')
414
415    density = asarray(density, int)
416    if density.size == 1:
417        d = density
418        density = np.empty(3, dtype=int)
419        density.fill(d)
420    elif density.shape != (3,):
421        raise ValueError("density should be in integer array of shape (3,)")
422
423    # Grid of points between -.5 and .5, centered at 0, with given density
424    grid = np.mgrid[0:density[0], 0:density[1], 0:density[2]]
425    grid = grid.T.reshape((-1, 3))
426    grid = grid / density
427    grid += (.5 / density - .5)
428
429    where = np.argwhere(mask)
430
431    # Add the grid of points to each voxel in mask
432    seeds = where[:, np.newaxis, :] + grid[np.newaxis, :, :]
433    seeds = seeds.reshape((-1, 3))
434
435    # Apply the spatial transform
436    if seeds.any():
437        # Use affine to move seeds into real world coordinates
438        seeds = np.dot(seeds, affine[:3, :3].T)
439        seeds += affine[:3, 3]
440
441    return seeds
442
443
444def random_seeds_from_mask(mask, affine, seeds_count=1,
445                           seed_count_per_voxel=True, random_seed=None):
446    """Create randomly placed seeds for fiber tracking from a binary mask.
447
448    Seeds points are placed randomly distributed in voxels of ``mask``
449    which are ``True``.
450    If ``seed_count_per_voxel`` is ``True``, this function is
451    similar to ``seeds_from_mask()``, with the difference that instead of
452    evenly distributing the seeds, it randomly places the seeds within the
453    voxels specified by the ``mask``.
454
455    Parameters
456    ----------
457    mask : binary 3d array_like
458        A binary array specifying where to place the seeds for fiber tracking.
459    affine : array, (4, 4)
460        The mapping between voxel indices and the point space for seeds.
461        The voxel_to_rasmm matrix, typically from a NIFTI file.
462        A seed point at the center the voxel ``[i, j, k]``
463        will be represented as ``[x, y, z]`` where
464        ``[x, y, z, 1] == np.dot(affine, [i, j, k , 1])``.
465    seeds_count : int
466        The number of seeds to generate. If ``seed_count_per_voxel`` is True,
467        specifies the number of seeds to place in each voxel. Otherwise,
468        specifies the total number of seeds to place in the mask.
469    seed_count_per_voxel: bool
470        If True, seeds_count is per voxel, else seeds_count is the total number
471        of seeds.
472    random_seed : int
473        The seed for the random seed generator (numpy.random.seed).
474
475    See Also
476    --------
477    seeds_from_mask
478
479    Raises
480    ------
481    ValueError
482        When ``mask`` is not a three-dimensional array
483
484    Examples
485    --------
486    >>> mask = np.zeros((3,3,3), 'bool')
487    >>> mask[0,0,0] = 1
488    >>> random_seeds_from_mask(mask, np.eye(4), seeds_count=1,
489    ... seed_count_per_voxel=True, random_seed=1)
490    array([[-0.0640051 , -0.47407377,  0.04966248]])
491    >>> random_seeds_from_mask(mask, np.eye(4), seeds_count=6,
492    ... seed_count_per_voxel=True, random_seed=1)
493    array([[-0.0640051 , -0.47407377,  0.04966248],
494           [ 0.0507979 ,  0.20814782, -0.20909526],
495           [ 0.46702984,  0.04723225,  0.47268436],
496           [-0.27800683,  0.37073231, -0.29328084],
497           [ 0.39286015, -0.16802019,  0.32122912],
498           [-0.42369171,  0.27991879, -0.06159077]])
499    >>> mask[0,1,2] = 1
500    >>> random_seeds_from_mask(mask, np.eye(4),
501    ... seeds_count=2, seed_count_per_voxel=True, random_seed=1)
502    array([[-0.0640051 , -0.47407377,  0.04966248],
503           [-0.27800683,  1.37073231,  1.70671916],
504           [ 0.0507979 ,  0.20814782, -0.20909526],
505           [-0.48962585,  1.00187459,  1.99577329]])
506    """
507    mask = np.array(mask, dtype=bool, copy=False, ndmin=3)
508    if mask.ndim != 3:
509        raise ValueError('mask cannot be more than 3d')
510
511    # Randomize the voxels
512    np.random.seed(random_seed)
513    shape = mask.shape
514    mask = mask.flatten()
515    indices = np.arange(len(mask))
516    np.random.shuffle(indices)
517
518    where = [np.unravel_index(i, shape) for i in indices if mask[i] == 1]
519    num_voxels = len(where)
520
521    if not seed_count_per_voxel:
522        # Generate enough seeds per voxel
523        seeds_per_voxel = seeds_count // num_voxels + 1
524    else:
525        seeds_per_voxel = seeds_count
526
527    seeds = []
528    for i in range(1, seeds_per_voxel + 1):
529        for s in where:
530            # Set the random seed with the current seed, the current value of
531            # seeds per voxel and the global random seed.
532            if random_seed is not None:
533                s_random_seed = hash((np.sum(s) + 1) * i + random_seed) \
534                    % (2**32 - 1)
535                np.random.seed(s_random_seed)
536            # Generate random triplet
537            grid = np.random.random(3)
538            seed = s + grid - .5
539            seeds.append(seed)
540    seeds = asarray(seeds)
541
542    if not seed_count_per_voxel:
543        # Select the requested amount
544        seeds = seeds[:seeds_count]
545
546    # Apply the spatial transform
547    if seeds.any():
548        # Use affine to move seeds into real world coordinates
549        seeds = np.dot(seeds, affine[:3, :3].T)
550        seeds += affine[:3, 3]
551
552    return seeds
553
554
555def _with_initialize(generator):
556    """Allows one to write a generator with initialization code.
557
558    All code up to the first yield is run as soon as the generator function is
559    called and the first yield value is ignored.
560    """
561    @wraps(generator)
562    def helper(*args, **kwargs):
563        gen = generator(*args, **kwargs)
564        next(gen)
565        return gen
566
567    return helper
568
569
570@_with_initialize
571def target(streamlines, affine, target_mask, include=True):
572    """Filters streamlines based on whether or not they pass through an ROI.
573
574    Parameters
575    ----------
576    streamlines : iterable
577        A sequence of streamlines. Each streamline should be a (N, 3) array,
578        where N is the length of the streamline.
579    affine : array (4, 4)
580        The mapping between voxel indices and the point space for seeds.
581        The voxel_to_rasmm matrix, typically from a NIFTI file.
582    target_mask : array-like
583        A mask used as a target. Non-zero values are considered to be within
584        the target region.
585    include : bool, default True
586        If True, streamlines passing through `target_mask` are kept. If False,
587        the streamlines not passing through `target_mask` are kept.
588
589    Returns
590    -------
591    streamlines : generator
592        A sequence of streamlines that pass through `target_mask`.
593
594    Raises
595    ------
596    ValueError
597        When the points of the streamlines lie outside of the `target_mask`.
598
599    See Also
600    --------
601    density_map
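
    Examples
    --------
    An illustrative sketch, keeping only the streamline that enters the voxel
    ``[0, 0, 0]`` of a small mask (identity affine):

    >>> mask = np.zeros((2, 2, 2), dtype=bool)
    >>> mask[0, 0, 0] = True
    >>> streamlines = [np.array([[0., 0., 0.], [1., 0., 0.]]),
    ...                np.array([[1., 1., 1.], [1., 1., 0.]])]
    >>> len(list(target(streamlines, np.eye(4), mask)))
    1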
602    """
603    target_mask = np.array(target_mask, dtype=bool, copy=True)
604    lin_T, offset = _mapping_to_voxel(affine)
605    yield
606    # End of initialization
607
608    for sl in streamlines:
609        try:
610            ind = _to_voxel_coordinates(sl, lin_T, offset)
611            i, j, k = ind.T
612            state = target_mask[i, j, k]
613        except IndexError:
614            raise ValueError("streamlines points are outside of target_mask")
615        if state.any() == include:
616            yield sl
617
618
619@_with_initialize
620def target_line_based(streamlines, affine, target_mask, include=True):
621    """Filters streamlines based on whether or not they pass through a ROI,
622    using a line-based algorithm. Mostly used as a replacement of `target`
623    for compressed streamlines.
624
625    This function never returns single-point streamlines, whatever the
626    value of `include`.
627
628    Parameters
629    ----------
630    streamlines : iterable
631        A sequence of streamlines. Each streamline should be a (N, 3) array,
632        where N is the length of the streamline.
633    affine : array (4, 4)
634        The mapping between voxel indices and the point space for seeds.
635        The voxel_to_rasmm matrix, typically from a NIFTI file.
636    target_mask : array-like
637        A mask used as a target. Non-zero values are considered to be within
638        the target region.
639    include : bool, default True
640        If True, streamlines passing through `target_mask` are kept. If False,
641        the streamlines not passing through `target_mask` are kept.
642
643    Returns
644    -------
645    streamlines : generator
646        A sequence of streamlines that pass through `target_mask`.
647
648    References
649    ----------
650    [Bresenham5] Bresenham, Jack Elton. "Algorithm for computer control of a
651                 digital plotter", IBM Systems Journal, vol 4, no. 1, 1965.
652    [Houde15] Houde et al. How to avoid biased streamlines-based metrics for
653              streamlines with variable step sizes, ISMRM 2015.
654
655    See Also
656    --------
657    dipy.tracking.utils.density_map
658    dipy.tracking.streamline.compress_streamlines
659    """
660    target_mask = np.array(target_mask, dtype=np.uint8, copy=True)
661    lin_T, offset = _mapping_to_voxel(affine)
662    streamline_index = _streamlines_in_mask(
663        streamlines, target_mask, lin_T, offset)
664    yield
665    # End of initialization
666
667    for idx in np.where(streamline_index == [0, 1][include])[0]:
668        yield streamlines[idx]
669
670
671def streamline_near_roi(streamline, roi_coords, tol, mode='any'):
672    """Is a streamline near an ROI.
673
674    Implements the inner loops of the :func:`near_roi` function.
675
676    Parameters
677    ----------
678    streamline : array, shape (N, 3)
679        A single streamline
680    roi_coords : array, shape (M, 3)
681        ROI coordinates transformed to the streamline coordinate frame.
682    tol : float
683        Distance (in the units of the streamlines, usually mm). If any
684        coordinate in the streamline is within this distance from the center
685        of any voxel in the ROI, this function returns True.
686    mode : string
687        One of {"any", "all", "either_end", "both_end"}, where return True
688        if:
689
690        "any" : any point is within tol from ROI.
691
692        "all" : all points are within tol from ROI.
693
694        "either_end" : either of the end-points is within tol from ROI
695
696        "both_end" : both end points are within tol from ROI.
697
698    Returns
699    -------
700    out : boolean
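
    Examples
    --------
    A small illustrative sketch, checking a two-point streamline against a
    single ROI coordinate:

    >>> streamline = np.array([[0., 0., 0.], [5., 0., 0.]])
    >>> roi_coords = np.array([[1., 0., 0.]])
    >>> bool(streamline_near_roi(streamline, roi_coords, tol=2., mode='any'))
    True
    >>> bool(streamline_near_roi(streamline, roi_coords, tol=2., mode='all'))
    False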
701    """
702    if len(roi_coords) == 0:
703        return False
704    if mode == "any" or mode == "all":
705        s = streamline
706    elif mode == "either_end" or mode == "both_end":
707        # 'end' modes, use a streamline with 2 nodes:
708        s = np.vstack([streamline[0], streamline[-1]])
709    else:
710        e_s = "For determining relationship to an array, you can use "
711        e_s += "one of the following modes: 'any', 'all', 'both_end',"
712        e_s += "'either_end', but you entered: %s." % mode
713        raise ValueError(e_s)
714
715    dist = cdist(s, roi_coords, 'euclidean')
716
717    if mode == "any" or mode == "either_end":
718        return np.min(dist) <= tol
719    else:
720        return np.all(np.min(dist, -1) <= tol)
721
722
723def near_roi(streamlines, affine, region_of_interest, tol=None,
724             mode="any"):
725    """Provide filtering criteria for a set of streamlines based on whether
726    they fall within a tolerance distance from an ROI
727
728    Parameters
729    ----------
730    streamlines : list or generator
731        A sequence of streamlines. Each streamline should be a (N, 3) array,
732        where N is the length of the streamline.
733    affine : array (4, 4)
734        The mapping between voxel indices and the point space for seeds.
735        The voxel_to_rasmm matrix, typically from a NIFTI file.
736    region_of_interest : ndarray
737        A mask used as a target. Non-zero values are considered to be within
738        the target region.
739    tol : float
740        Distance (in the units of the streamlines, usually mm). If any
741        coordinate in the streamline is within this distance from the center
742        of any voxel in the ROI, the filtering criterion is set to True for
743        this streamline, otherwise False. Defaults to the distance between
744        the center of each voxel and the corner of the voxel.
745    mode : string, optional
746        One of {"any", "all", "either_end", "both_end"}, where return True
747        if:
748
749        "any" : any point is within tol from ROI. Default.
750
751        "all" : all points are within tol from ROI.
752
753        "either_end" : either of the end-points is within tol from ROI
754
755        "both_end" : both end points are within tol from ROI.
756
757    Returns
758    -------
759    1D array of boolean dtype, shape (len(streamlines), )
760
761    This contains `True` for indices corresponding to each streamline
762    that passes within a tolerance distance from the target ROI, `False`
763    otherwise.
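
    Examples
    --------
    An illustrative sketch with an identity affine and a single-voxel ROI;
    the first streamline touches the ROI voxel, the second stays away:

    >>> roi = np.zeros((3, 3, 3), dtype=bool)
    >>> roi[1, 1, 1] = True
    >>> streamlines = [np.array([[1., 1., 1.], [2., 2., 2.]]),
    ...                np.array([[0., 0., 2.5], [0., 0., 2.6]])]
    >>> [bool(b) for b in near_roi(streamlines, np.eye(4), roi, tol=1.0)]
    [True, False]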
764    """
765    dtc = dist_to_corner(affine)
766    if tol is None:
767        tol = dtc
768    elif tol < dtc:
769        w_s = "Tolerance input provided would create gaps in your"
770        w_s += " inclusion ROI. Setting to: %s" % dtc
771        warn(w_s)
772        tol = dtc
773
774    roi_coords = np.array(np.where(region_of_interest)).T
775    x_roi_coords = apply_affine(affine, roi_coords)
776
777    # If it's already a list, we can save time by pre-allocating the output
778    if isinstance(streamlines, list):
779        out = np.zeros(len(streamlines), dtype=bool)
780        for ii, sl in enumerate(streamlines):
781            out[ii] = streamline_near_roi(sl, x_roi_coords, tol=tol,
782                                          mode=mode)
783        return out
784    # If it's a generator, we'll need to generate the output into a list
785    else:
786        out = []
787        for sl in streamlines:
788            out.append(streamline_near_roi(sl, x_roi_coords, tol=tol,
789                                           mode=mode))
790
791        return(np.array(out, dtype=bool))
792
793
794def length(streamlines):
795    """
796    Calculate the lengths of many streamlines in a bundle.
797
798    Parameters
799    ----------
800    streamlines : list
801        Each item in the list is an array with 3D coordinates of a streamline.
802
803    Returns
804    -------
805    Iterator object which then computes the length of each
806    streamline in the bundle, upon iteration.
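
    Examples
    --------
    A minimal, illustrative example; segment lengths are summed per
    streamline:

    >>> streamlines = [np.array([[0., 0., 0.], [3., 0., 0.]]),
    ...                np.array([[0., 0., 0.], [0., 4., 0.], [0., 4., 3.]])]
    >>> [float(d) for d in length(streamlines)]
    [3.0, 7.0]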
807    """
808
809    return map(metrics.length, streamlines)
810
811
812def unique_rows(in_array, dtype='f4'):
813    """
814    This (quickly) finds the unique rows in an array
815
816    Parameters
817    ----------
818    in_array: ndarray
819        The array for which the unique rows should be found
820
821    dtype: str, optional
822        This determines the intermediate representation used for the
823        values. Should at least preserve the values of the input array.
824
825    Returns
826    -------
827    u_return: ndarray
828       Array with the unique rows of the original array.
829
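
    Examples
    --------
    For illustration, duplicate rows are removed while the remaining rows
    keep their original order:

    >>> a = np.array([[1, 2, 3],
    ...               [1, 2, 3],
    ...               [4, 5, 6]])
    >>> len(unique_rows(a))
    2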
830    """
831    # Sort input array
832    order = np.lexsort(in_array.T)
833
834    # Apply sort and compare neighbors
835    x = in_array[order]
836    diff_x = np.ones(len(x), dtype=bool)
837    diff_x[1:] = (x[1:] != x[:-1]).any(-1)
838
839    # Reverse sort and return unique rows
840    un_order = order.argsort()
841    diff_in_array = diff_x[un_order]
842    return in_array[diff_in_array]
843
844
845@_with_initialize
846def transform_tracking_output(tracking_output, affine, save_seeds=False):
847    """Applies a linear transformation, given by affine, to streamlines.
848    Parameters
849    ----------
850    streamlines : Streamlines generator
851        Either streamlines (list, ArraySequence) or a tuple with streamlines
852        and seeds together
853    affine : array (4, 4)
854        The mapping between voxel indices and the point space for seeds.
855        The voxel_to_rasmm matrix, typically from a NIFTI file.
856    save_seeds : bool, optional
857        If set, seeds associated to streamlines will be also moved and returned
858    Returns
859    -------
860    streamlines : generator
861        A generator for the sequence of transformed streamlines.
862        If save_seeds is True, also return a generator for the
863        transformed seeds.
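
    Examples
    --------
    An illustrative sketch: applying a pure translation to a single
    streamline:

    >>> affine = np.eye(4)
    >>> affine[:3, 3] = [10., 0., 0.]
    >>> sl = [np.array([[0., 0., 0.], [1., 0., 0.]])]
    >>> out = list(transform_tracking_output(sl, affine))
    >>> bool(np.allclose(out[0], [[10., 0., 0.], [11., 0., 0.]]))
    True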
864    """
865    if save_seeds:
866        streamlines, seeds = zip(*tracking_output)
867    else:
868        streamlines = tracking_output
869        seeds = None
870
871    lin_T = affine[:3, :3].T.copy()
872    offset = affine[:3, 3].copy()
873    yield
874    # End of initialization
875
876    if seeds is not None:
877        for sl, seed in zip(streamlines, seeds):
878            yield np.dot(sl, lin_T) + offset, np.dot(seed, lin_T) + offset
879    else:
880        for sl in streamlines:
881            yield np.dot(sl, lin_T) + offset
882
883
884def reduce_rois(rois, include):
885    """Reduce multiple ROIs to one inclusion and one exclusion ROI.
886
887    Parameters
888    ----------
889    rois : list or ndarray
890        A list of 3D arrays, each with shape (x, y, z) corresponding to the
891        shape of the brain volume, or a 4D array with shape (n_rois, x, y,
892        z). Non-zeros in each volume are considered to be within the region.
893
894    include : array or list
895        A list or 1D array of boolean marking inclusion or exclusion
896        criteria.
897
898    Returns
899    -------
900    include_roi : boolean 3D array
901        An array marking the inclusion mask.
902
903    exclude_roi : boolean 3D array
904        An array marking the exclusion mask
905
906    Notes
907    -----
908    The include_roi and exclude_roi can be used to perfom the operation: "(A
909    or B or ...) and not (X or Y or ...)", where A, B are inclusion regions
910    and X, Y are exclusion regions.
911
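
    Examples
    --------
    An illustrative sketch combining one inclusion and one exclusion ROI:

    >>> rois = [np.zeros((2, 2, 2), dtype=bool) for _ in range(2)]
    >>> rois[0][0, 0, 0] = True
    >>> rois[1][1, 1, 1] = True
    >>> include_roi, exclude_roi = reduce_rois(rois, [True, False])
    >>> bool(include_roi[0, 0, 0]), bool(exclude_roi[1, 1, 1])
    (True, True)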
912    """
913    include_roi = np.zeros(rois[0].shape, dtype=bool)
914    exclude_roi = np.zeros(rois[0].shape, dtype=bool)
915
916    for i in range(len(rois)):
917        if include[i]:
918            include_roi |= rois[i]
919        else:
920            exclude_roi |= rois[i]
921
922    return include_roi, exclude_roi
923
924
925def _min_at(a, index, value):
926    index = np.asarray(index)
927    sort_keys = [value] + list(index)
928    order = np.lexsort(sort_keys)
929    index = index[:, order]
930    value = value[order]
931    uniq = np.ones(index.shape[1], dtype=bool)
932    uniq[1:] = (index[:, 1:] != index[:, :-1]).any(axis=0)
933
934    index = index[:, uniq]
935    value = value[uniq]
936
937    a[tuple(index)] = np.minimum(a[tuple(index)], value)
938
939
940try:
941    minimum_at = np.minimum.at
942except AttributeError:
943    minimum_at = _min_at
944
945
946def path_length(streamlines, affine, aoi, fill_value=-1):
947    """ Computes the shortest path, along any streamline, between aoi and
948    each voxel.
949
950    Parameters
951    ----------
952    streamlines : seq of (N, 3) arrays
953        A sequence of streamlines, path length is given in mm along the curve
954        of the streamline.
955    aoi : array, 3d
956        A mask (binary array) of voxels from which to start computing distance.
957    affine : array (4, 4)
958        The mapping between voxel indices and the point space for seeds.
959        The voxel_to_rasmm matrix, typically from a NIFTI file.
960    fill_value : float
961        The value of voxel in the path length map that are not connected to the
962        aoi.
963
964    Returns
965    -------
966    plm : array
967        Same shape as aoi. The minimum distance between every point and aoi
968        along the path of a streamline.
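
    Examples
    --------
    A minimal, illustrative example with an identity affine: the streamline
    starts inside the aoi, so the next voxel it visits is one step (1 mm)
    away:

    >>> aoi = np.zeros((2, 2, 2), dtype=bool)
    >>> aoi[0, 0, 0] = True
    >>> streamlines = [np.array([[0., 0., 0.], [1., 0., 0.]])]
    >>> plm = path_length(streamlines, np.eye(4), aoi)
    >>> float(plm[0, 0, 0]), float(plm[1, 0, 0])
    (0.0, 1.0)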
969    """
970    aoi = np.asarray(aoi, dtype=bool)
971
972    # path length map
973    plm = np.empty(aoi.shape, dtype=float)
974    plm[:] = np.inf
975    lin_T, offset = _mapping_to_voxel(affine)
976    for sl in streamlines:
977        seg_ind = _to_voxel_coordinates(sl, lin_T, offset)
978        i, j, k = seg_ind.T
979        # Get where streamlines passes through aoi
980        breaks = aoi[i, j, k]
981        # Where streamline passes aoi, dist is zero
982        i, j, k = seg_ind[breaks].T
983        plm[i, j, k] = 0
984
985        # If a streamline crosses aoi >1, re-start counting distance for each
986        for seg in _as_segments(sl, breaks):
987            i, j, k = _to_voxel_coordinates(seg[1:], lin_T, offset).T
988            # Get the distance, in mm, between streamline points
989            segment_length = np.sqrt(((seg[1:] - seg[:-1]) ** 2).sum(1))
990            dist = segment_length.cumsum()
991            # Updates path length map with shorter distances
992            minimum_at(plm, (i, j, k), dist)
993    if fill_value != np.inf:
994        plm = np.where(plm == np.inf, fill_value, plm)
995    return plm
996
997
998def _part_segments(streamline, break_points):
999    segments = np.split(streamline, break_points.nonzero()[0])
1000    # Skip first segment, all points before first break
1001    # first segment is empty when break_points[0] == 0
1002    segments = segments[1:]
1003    for each in segments:
1004        if len(each) > 1:
1005            yield each
1006
1007
1008def _as_segments(streamline, break_points):
1009    for seg in _part_segments(streamline, break_points):
1010        yield seg
1011    for seg in _part_segments(streamline[::-1], break_points[::-1]):
1012        yield seg
1013